Source code

Revision control

Copy as Markdown

Other Tools

#!/usr/bin/env python3
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import argparse
import hashlib
import os
import shutil
import subprocess
import sys
import urllib.request
from pathlib import Path
import yaml
# Absolute directory that contains this script.
HERE = Path(__file__).resolve().parent
# YAML with the fetch definitions read by main(); it lives five directory
# levels above HERE, under taskcluster/kinds/fetch/.
FETCH_FILE = HERE.joinpath(
    *([".."] * 5), "taskcluster", "kinds", "fetch", "onnxruntime-web-fetch.yml"
).resolve()
def is_git_lfs_installed():
    """Report whether the ``git lfs`` subcommand is usable.

    Runs ``git lfs version`` (stderr discarded) and looks for ``git-lfs`` in
    its output, case-insensitively. A missing ``git`` executable or a
    non-zero exit status both yield ``False``.
    """
    version_cmd = ["git", "lfs", "version"]
    try:
        banner = subprocess.check_output(
            version_cmd, stderr=subprocess.DEVNULL, text=True
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False
    return "git-lfs" in banner.lower()
def compute_sha256(file_path):
    """Return the lowercase hex SHA-256 digest of the file at ``file_path``.

    ``file_path`` must be a ``pathlib.Path`` (its ``open`` method is used).
    The file is read in binary mode in 4096-byte chunks, so it is never held
    in memory all at once.
    """
    digest = hashlib.sha256()
    with file_path.open("rb") as stream:
        while block := stream.read(4096):
            digest.update(block)
    return digest.hexdigest()
def download_wasm(fetches, fetches_dir):
    """
    Download and verify ort.jsep.wasm if needed,
    using the 'ort.jsep.wasm' entry in the YAML file.

    The file is stored in ``fetches_dir`` under the last path component of
    the fetch URL. An existing file whose SHA-256 already matches is kept.
    Otherwise the content is streamed into a sibling ``<name>.part`` file,
    hashed while it is written, and only moved to the final path once the
    checksum matches. An interrupted or corrupt download therefore never
    leaves a bad file at the final path.

    :param fetches: parsed fetch YAML; must contain
        ``fetches["ort.jsep.wasm"]["fetch"]`` with ``url`` and ``sha256`` keys.
    :param fetches_dir: ``pathlib.Path`` of the directory to download into.
    :raises ValueError: if the downloaded content does not match ``sha256``.
    """
    wasm_fetch = fetches["ort.jsep.wasm"]["fetch"]
    url = wasm_fetch["url"]
    # hexdigest() is lowercase; normalise so an upper-case YAML value still matches.
    expected_sha256 = str(wasm_fetch["sha256"]).lower()
    filename = url.split("/")[-1]
    output_file = fetches_dir / filename

    # If the file exists and its checksum matches, skip re-download
    if output_file.exists():
        print(f"Found existing file {output_file}, verifying checksum...")
        if compute_sha256(output_file) == expected_sha256:
            print("Existing file's checksum matches. Skipping download.")
            return
        print("Checksum mismatch on existing file. Removing and re-downloading...")
        output_file.unlink()

    # Download into a temporary sibling so the final path only ever holds a
    # fully downloaded, verified file.
    tmp_file = output_file.with_name(output_file.name + ".part")
    print(f"Downloading {url} to {output_file}...")
    hasher = hashlib.sha256()
    try:
        with urllib.request.urlopen(url) as response, tmp_file.open("wb") as out_file:
            while chunk := response.read(65536):
                hasher.update(chunk)
                out_file.write(chunk)

        # Verify SHA-256 (computed while streaming, so no second read is needed)
        print(f"Verifying SHA-256 of {output_file}...")
        downloaded_sha256 = hasher.hexdigest()
        if downloaded_sha256 != expected_sha256:
            raise ValueError(
                f"Checksum mismatch for {filename}! "
                f"Expected: {expected_sha256}, got: {downloaded_sha256}"
            )
        # Path.replace() is an os.replace() rename within the same directory.
        tmp_file.replace(output_file)
    finally:
        # Drops a partial/corrupt download on any failure; no-op after replace().
        tmp_file.unlink(missing_ok=True)

    print(f"File {filename} downloaded and verified successfully!")
def list_models(fetches):
    """Print every YAML key whose ``fetch.type`` is ``"git"``, with its path-prefix.

    Entries with no (or an empty) ``fetch`` mapping, or with any other fetch
    type, are skipped. An entry without ``path-prefix`` is shown with a
    placeholder. A usage hint about ``--model`` is printed at the end.
    """
    print("Available git-based models from the YAML:\n")
    for name, entry in fetches.items():
        fetch_spec = entry.get("fetch")
        if not fetch_spec or fetch_spec.get("type") != "git":
            continue
        prefix = fetch_spec.get("path-prefix", "[no path-prefix specified]")
        print(f"- {name} -> path-prefix: {prefix}")
    print("\n(Use `--model <key>` to clone one of these repositories.)")
def _git_output(args, repo_dir):
    """Return the stripped stdout of ``git <args>`` run in ``repo_dir``, or None if it fails."""
    try:
        return subprocess.check_output(
            ["git", *args], cwd=repo_dir, text=True, stderr=subprocess.DEVNULL
        ).strip()
    except subprocess.CalledProcessError:
        return None


def _remove_path(path):
    """Delete ``path`` (directory tree, file or symlink); removal errors propagate."""
    if path.is_dir() and not path.is_symlink():
        shutil.rmtree(path)
    else:
        path.unlink()


def _repo_matches(repo_dir, repo_url, revision):
    """Return True if ``repo_dir`` is a clone of ``repo_url`` with ``revision`` checked out.

    Prints the reason (and that the directory will be removed) when it is not.
    """
    # 1. Check if .git exists
    if not (repo_dir / ".git").is_dir():
        print(f"Directory '{repo_dir}' exists but is not a git repo. Removing it.")
        return False

    # 2. Check if remote origin URL matches
    existing_url = _git_output(["remote", "get-url", "origin"], repo_dir)
    if existing_url != repo_url:
        print(
            f"Repository at '{repo_dir}' has remote '{existing_url}' "
            f"instead of '{repo_url}'. Removing it."
        )
        return False

    # 3. Check if HEAD is the commit 'revision' names. The revision is resolved
    # to a commit first: comparing the raw string (e.g. the default "main", a
    # tag or an abbreviated SHA) against HEAD's full SHA never matched and
    # forced a re-clone on every run. For a branch name this compares with the
    # local branch tip; nothing is fetched, so pin a commit SHA in the YAML for
    # reproducibility.
    current_revision = _git_output(["rev-parse", "HEAD"], repo_dir)
    wanted_commit = _git_output(
        ["rev-parse", "--verify", "--quiet", f"{revision}^{{commit}}"], repo_dir
    )
    if current_revision is None or current_revision != wanted_commit:
        print(
            f"Repo at '{repo_dir}' has HEAD {current_revision}, "
            f"but we need '{revision}'. Removing it."
        )
        return False
    return True


def clone_model(key, data, fetches_dir):
    """
    Clone (or re-clone) a model if needed.
    The directory is determined by 'path-prefix' from the YAML,
    relative to --fetches-dir. Example:
    path-prefix: "onnx-models/Xenova/all-MiniLM-L6-v2/main/"
    We'll end up cloning to <fetches-dir>/onnx-models/Xenova/all-MiniLM-L6-v2/main

    An existing target is kept only if it is a git repo whose ``origin`` is
    ``data["fetch"]["repo"]`` and whose HEAD is the commit that ``revision``
    (default ``"main"``) resolves to. Anything else is removed and re-cloned.
    Removal errors are raised rather than ignored, so a stale checkout is
    never reported as up to date.

    :param key: YAML key of the model (not used by the cloning logic).
    :param data: YAML entry; ``data["fetch"]`` must have ``repo`` and
        ``path-prefix``, and may have ``revision``.
    :param fetches_dir: ``pathlib.Path`` base directory for ``path-prefix``.
    :raises subprocess.CalledProcessError: if ``git clone``/``git checkout`` fails.
    """
    fetch_data = data["fetch"]
    repo_url = fetch_data["repo"]
    path_prefix = fetch_data["path-prefix"]
    revision = fetch_data.get("revision", "main")

    # Compute the final directory from --fetches-dir + path-prefix
    repo_dir = fetches_dir / path_prefix
    # Ensure parent directories exist
    repo_dir.parent.mkdir(parents=True, exist_ok=True)

    # Remove whatever is there if it is not the right repo at the right revision
    if repo_dir.exists() and not _repo_matches(repo_dir, repo_url, revision):
        _remove_path(repo_dir)

    if repo_dir.exists():
        print(f"{repo_dir} already exists and is up to date. Skipping clone.")
        return

    print(f"Cloning {repo_url} into '{repo_dir}'...")
    # Normal clone first
    subprocess.run(["git", "clone", repo_url, str(repo_dir)], check=True)
    # Then checkout the desired revision (branch, commit, or tag)
    subprocess.run(["git", "checkout", revision], cwd=repo_dir, check=True)
    print(f"Checked out revision '{revision}' in '{repo_dir}'.")
def clone_models(keys, fetches, fetches_dir):
    """
    Clone each model specified by YAML key, if fetch.type == 'git'.
    Uses the path-prefix from the YAML to determine the final directory.

    Every key is validated before anything is cloned. A typo in a later
    ``--model`` argument therefore fails immediately, not after earlier
    (potentially large) clones have finished. Duplicate keys are cloned once,
    in first-seen order.

    :param keys: iterable of YAML keys; ``None``/empty is a no-op.
    :param fetches: parsed fetch YAML mapping.
    :param fetches_dir: ``pathlib.Path`` passed through to ``clone_model``.
    :raises ValueError: if a key is missing from the YAML or is not a git fetch.
    """
    if not keys:
        return

    # dict.fromkeys drops duplicates while preserving command-line order.
    unique_keys = list(dict.fromkeys(keys))
    for key in unique_keys:
        if key not in fetches:
            raise ValueError(f"Model '{key}' not found in YAML.")
        # `or {}` tolerates entries whose value or `fetch` is null in the YAML.
        fetch = (fetches[key] or {}).get("fetch") or {}
        if fetch.get("type") != "git":
            raise ValueError(f"Model '{key}' is not a git fetch type.")

    # Initialize git lfs once (if we have at least one model)
    subprocess.run(["git", "lfs", "install"], check=True)
    for key in unique_keys:
        clone_model(key, fetches[key], fetches_dir)
def main():
    """Command-line entry point.

    Parses arguments and loads ``FETCH_FILE``. With ``--list-models`` it
    prints the git-based models and returns. Otherwise it downloads/verifies
    ``ort.jsep.wasm`` into the fetches directory and clones every ``--model``
    requested.

    git-lfs is checked only when at least one ``--model`` is requested.
    ``--help``, ``--list-models`` and the wasm download do not use git, so
    they no longer fail on machines without git-lfs.
    """
    parser = argparse.ArgumentParser(
        description="Download ort.jsep.wasm and optionally clone specified models."
    )
    # An empty MOZ_ML_LOCAL_DIR is treated as unset (it would otherwise resolve to the cwd).
    default_dir = os.getenv("MOZ_ML_LOCAL_DIR") or None
    parser.add_argument(
        "--fetches-dir",
        help="Directory to store the downloaded files (and cloned repos). Uses MOZ_ML_LOCAL_DIR if present.",
        default=default_dir,
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List all available git-based models (keys in the YAML) and exit.",
    )
    parser.add_argument(
        "--model",
        action="append",
        help="YAML key of a model to clone (can be specified multiple times).",
    )
    args = parser.parse_args()

    # Load YAML
    with FETCH_FILE.open("r", encoding="utf-8") as f:
        fetches = yaml.safe_load(f)

    # If listing models, do so and exit
    if args.list_models:
        list_models(fetches)
        return

    if args.fetches_dir is None:
        # parser.error prints the usage line and exits with status 2 instead of a traceback.
        parser.error(
            "Missing --fetches-dir argument or MOZ_ML_LOCAL_DIR env var. Please specify a directory to store the downloaded files"
        )

    # Only cloning needs git-lfs; check before downloading anything so we fail fast.
    if args.model and not is_git_lfs_installed():
        print("git lfs is required to clone models (--model):", file=sys.stderr)
        print("\t$ sudo apt install git-lfs", file=sys.stderr)
        print("\t$ sudo yum install git-lfs", file=sys.stderr)
        print("\t$ brew install git-lfs", file=sys.stderr)
        print(file=sys.stderr)
        sys.exit(1)

    fetches_dir = Path(args.fetches_dir).resolve()
    fetches_dir.mkdir(parents=True, exist_ok=True)

    # Always download/verify ort.jsep.wasm
    download_wasm(fetches, fetches_dir)

    # Clone requested models
    if args.model:
        clone_models(args.model, fetches, fetches_dir)


if __name__ == "__main__":
    main()