Source code for djura.data_loader
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Djura | Risk - Data - Engineering S.r.l.
"""Download, cache, and load the bundled NGA-West2 pickle dataset.
The dataset is too large (>100 MB) to ship inside the wheel, so it is
hosted as a gzip-compressed asset on a GitHub Release and fetched on
first use into a per-user cache directory.
"""
import gzip
import hashlib
import os
import pickle
import shutil
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any
PACKAGE_NAME = "djura"
DATA_FILENAME = "NGA_W2_v2.pickle"

# Release asset location and checksum. Update BOTH constants (and re-run the
# release-data workflow) whenever the dataset changes. Compute the new hash
# of the compressed asset with (hashlib.file_digest needs Python 3.11+):
#   python -c "import hashlib,sys; \
#   print(hashlib.file_digest(open(sys.argv[1],'rb'),'sha256').hexdigest())" \
#   NGA_W2_v2.pickle.gz
# or, equivalently, `sha256sum NGA_W2_v2.pickle.gz`.
GITHUB_RELEASE_URL = (
    "https://github.com/djura-risk-data-engineering/djura"
    "/releases/download/data-v1/NGA_W2_v2.pickle.gz"
)
# Hex SHA-256 of the compressed .gz asset at GITHUB_RELEASE_URL. A download
# whose digest differs is discarded with an error. An empty string disables
# the check entirely, so keep this populated for released versions.
EXPECTED_SHA256 = (
    "21cdc4519483e5f1187ebf69d04c9643fa771507f30026f39886c6e29af3aa4e"
)

# Bytes per MiB; used to express the download cap and to format its error
# message.
_MB = 1024 ** 2
# Refuse downloads larger than 500 MiB. The cap applies to the compressed
# bytes streamed from the server (the uncompressed pickle is ~107 MB), so it
# leaves generous headroom while still stopping a runaway response.
_MAX_DOWNLOAD_BYTES = 500 * _MB
def _cache_dir() -> Path:
    """Return the per-user cache directory for this package.

    The path is ``~/.cache/<PACKAGE_NAME>``. It is only computed here, not
    created.
    """
    home = Path.home()
    return home.joinpath(".cache", PACKAGE_NAME)
def _cache_path() -> Path:
    """Return the location of the cached, decompressed dataset pickle."""
    directory = _cache_dir()
    return directory.joinpath(DATA_FILENAME)
def _sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*.

    The file is streamed in 1 MiB pieces, so even the large dataset asset is
    never read into memory all at once.
    """
    digest = hashlib.sha256()
    piece_size = 1 << 20
    with open(path, "rb") as stream:
        while piece := stream.read(piece_size):
            digest.update(piece)
    return digest.hexdigest()
def _download_and_extract(dest: Path) -> None:
    """Download the gzip-compressed dataset and decompress it to *dest*.

    The asset at ``GITHUB_RELEASE_URL`` is streamed to ``<dest>.gz.part``. It
    is then checked against ``EXPECTED_SHA256``; the check is skipped when
    that constant is empty. Next it is decompressed into ``<dest>.part``, and
    only after decompression finishes is the result moved onto *dest* with
    ``Path.replace``.

    Both temporary files are removed on every exit path: success, a failed
    or aborted download, a hash mismatch, a decompression error, or an
    interrupt. An aborted run therefore never leaves stale ``.part`` files
    in the cache directory.

    Args:
        dest: Final path of the decompressed pickle. Its parent directory is
            created if it does not exist.

    Raises:
        RuntimeError: If the download fails (HTTP error, network error,
            timeout or connection reset), exceeds ``_MAX_DOWNLOAD_BYTES``,
            or fails SHA-256 verification. Errors raised while decompressing
            (e.g. a corrupt archive when hash checking is disabled) propagate
            unchanged.
    """
    url = GITHUB_RELEASE_URL
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp_gz = dest.with_suffix(dest.suffix + ".gz.part")
    tmp_pkl = dest.with_suffix(dest.suffix + ".part")
    try:
        try:
            with urllib.request.urlopen(url, timeout=120) as response, \
                    open(tmp_gz, "wb") as out:
                downloaded = 0
                while chunk := response.read(1 << 20):
                    downloaded += len(chunk)
                    # Enforce the cap while streaming, so an unexpectedly
                    # large (or endless) response cannot fill the disk.
                    if downloaded > _MAX_DOWNLOAD_BYTES:
                        raise RuntimeError(
                            f"Download from {url} exceeded "
                            f"{_MAX_DOWNLOAD_BYTES // _MB} MB limit — "
                            "aborting."
                        )
                    out.write(chunk)
        except urllib.error.HTTPError as e:
            # HTTPError subclasses URLError, so it must be handled first.
            raise RuntimeError(
                f"Failed to download dataset from {url} "
                f"(HTTP {e.code}). Make sure the GitHub Release "
                "exists and the asset is public."
            ) from e
        except urllib.error.URLError as e:
            raise RuntimeError(
                f"Failed to download dataset from {url}: "
                f"{e.reason}. Check your network connection."
            ) from e
        except (TimeoutError, ConnectionError) as e:
            # A timeout or reset while reading the body is not wrapped in
            # URLError. socket.timeout is an alias of TimeoutError on 3.10+.
            raise RuntimeError(
                f"Download from {url} was interrupted ({e}). "
                "Check your network connection and try again."
            ) from e

        if EXPECTED_SHA256:
            actual = _sha256(tmp_gz)
            if actual != EXPECTED_SHA256:
                raise RuntimeError(
                    "SHA-256 mismatch for downloaded asset.\n"
                    f" expected: {EXPECTED_SHA256}\n"
                    f" actual: {actual}\n"
                    "The file may be corrupted or tampered with. "
                    "The download has been discarded; try again, or "
                    "report the issue at https://github.com/"
                    "djura-risk-data-engineering/djura/issues"
                )

        with gzip.open(tmp_gz, "rb") as gz_in, \
                open(tmp_pkl, "wb") as pkl_out:
            shutil.copyfileobj(gz_in, pkl_out)
        # dest appears only once the decompressed data is complete.
        tmp_pkl.replace(dest)
    finally:
        # Single cleanup point for all outcomes. After a successful replace()
        # tmp_pkl no longer exists, so missing_ok keeps this a no-op.
        tmp_gz.unlink(missing_ok=True)
        tmp_pkl.unlink(missing_ok=True)
[docs]
def load_data() -> Any:
    """Return the deserialized dataset, downloading and caching it if needed.

    If the cached pickle is missing, it is fetched into the per-user cache
    first. Every call unpickles the cache file again; ``get_nga_west2()``
    provides a memoized, once-per-process copy instead.

    NOTE(review): the cache file is fed straight to ``pickle.load``, so it is
    trusted input. It is only as safe as the download's SHA-256 check and the
    cache directory's permissions.
    """
    path = _cache_path()
    if not path.exists():
        _download_and_extract(path)
    with path.open("rb") as handle:
        data = pickle.load(handle)
    return data
[docs]
def clear_cache() -> None:
    """Delete the on-disk dataset and forget the in-memory copy.

    After this call, the next ``load_data()`` downloads the dataset again and
    the next ``get_nga_west2()`` reloads it. A cache file that is already
    missing is not an error.
    """
    global _nga_west2
    # Reset the memo before touching the filesystem, as the original did.
    _nga_west2 = None
    _cache_path().unlink(missing_ok=True)
# Per-process memo filled by get_nga_west2() and reset to None by
# clear_cache(); None means "not loaded yet".
_nga_west2: Any = None
[docs]
def get_nga_west2() -> Any:
    """Return the NGA-West2 metadata, loading it at most once per process.

    The first call loads the data and memoizes it in a module-level
    variable. Later calls return that same object until ``clear_cache()``
    resets it. If the ``DJURA_METADATA_PATH`` environment variable is set to
    a non-empty path, that custom pickle file is loaded instead of the
    downloaded dataset. The variable is consulted only when nothing has been
    memoized yet.
    """
    global _nga_west2
    if _nga_west2 is not None:
        return _nga_west2
    override = os.environ.get("DJURA_METADATA_PATH")
    if not override:
        _nga_west2 = load_data()
    else:
        with open(override, "rb") as handle:
            _nga_west2 = pickle.load(handle)
    return _nga_west2