Source code
Revision control
Copy as Markdown
Other Tools
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import http.server
import os
import posixpath
import socket
import socketserver
import threading
import urllib.parse
from pathlib import Path
# Server threads started by before_runs(); after_runs() drains this list.
THREADS = list()
class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
    """Serve files from ``hub_root``, tagging each file with a size-based ETag.

    Directory requests are delegated to ``SimpleHTTPRequestHandler``. For a
    regular file, the ETag is the file size in bytes wrapped in quotes, and
    conditional GETs carrying a matching ``If-None-Match`` get a 304.
    """

    # Filesystem directory served as the hub root (str or Path). The server
    # setup code in this module assigns it on the class before serving.
    hub_root = ""

    def translate_path(self, path):
        """Map a request path to a filesystem path under ``hub_root``.

        The query string and fragment are dropped and percent-escapes are
        decoded. ``.``/``..`` segments are resolved lexically, and any segment
        that would climb above the root is discarded. A request therefore
        cannot name a file outside ``hub_root``.

        Returns:
            The filesystem path as a ``str``.
        """
        path = path.split("?", 1)[0].split("#", 1)[0]
        try:
            path = urllib.parse.unquote(path, errors="surrogatepass")
        except UnicodeDecodeError:
            path = urllib.parse.unquote(path)
        resolved = Path(self.hub_root)
        for word in posixpath.normpath(path).split("/"):
            # Drop empty parts, leftover "."/".." (normpath keeps leading ".."
            # on relative paths), and anything carrying a directory or drive
            # component (e.g. backslash-separated names on Windows).
            if not word or word in (os.curdir, os.pardir) or os.path.dirname(word):
                continue
            resolved = resolved / word
        return str(resolved)

    def _if_none_match_hits(self, etag):
        """Return True if the request's ``If-None-Match`` matches ``etag``.

        Accepts ``*`` and comma-separated lists. Comparison is weak (a ``W/``
        prefix is ignored), which is what If-None-Match calls for.
        """
        header = self.headers.get("If-None-Match")
        if header is None:
            return False
        for candidate in header.split(","):
            candidate = candidate.strip()
            if candidate.startswith("W/"):
                candidate = candidate[2:]
            if candidate == "*" or candidate == etag:
                return True
        return False

    def send_head(self):
        """Send the status line and headers for a GET/HEAD request.

        Directories go to the base class. For a file, the ETag is the quoted
        file size; a matching ``If-None-Match`` yields ``304 Not Modified``
        with no body. A missing or unreadable path yields a 404.

        Returns:
            An open binary file object for the body, which the caller copies
            and closes. Returns ``None`` when no body is to be sent.
        """
        path = Path(self.translate_path(self.path))
        if path.is_dir():
            return super().send_head()
        if not path.is_file():
            self.send_error(404, "File not found")
            return None
        try:
            # Open before any status line goes out, so an unreadable file
            # still gets a clean error response instead of a truncated 200.
            f = path.open("rb")
        except OSError:
            self.send_error(404, "File not found")
            return None
        try:
            # Size from the open descriptor, so the ETag and Content-Length
            # describe exactly the bytes that will be sent.
            file_size = os.fstat(f.fileno()).st_size
            etag = f'"{file_size}"'
            if self._if_none_match_hits(etag):
                f.close()
                self.send_response(304)
                # A 304 must carry the ETag that a 200 would have sent.
                self.send_header("ETag", etag)
                self.end_headers()
                return None
            self.send_response(200)
            self.send_header("Content-type", self.guess_type(str(path)))
            self.send_header("Content-Length", str(file_size))
            self.send_header("ETag", etag)
            self.end_headers()
            return f
        except BaseException:
            f.close()
            raise
def serve_directory(directory, port):
    """Serves the directory at the given port.

    Points ``CustomHTTPRequestHandler`` at *directory*, then runs a TCP server
    on every interface at *port*. The call blocks inside ``serve_forever()``;
    the server socket is closed when that loop exits.
    """
    handler_cls = CustomHTTPRequestHandler
    handler_cls.hub_root = directory
    server = socketserver.TCPServer(("", port), handler_cls)
    with server:
        server.serve_forever()
def start_hub(root_directory):
    """Starts a local hub server and returns the port and thread.

    The server is created and bound here, on an OS-chosen free port, before
    the thread starts. The port is therefore already accepting connections
    when this returns. (Probing for a free port, closing it and re-binding
    later would leave a window where another process could take the port, or
    where a client connects before the server is listening.)

    The running server is exposed on the returned thread as ``httpd``, so a
    caller can stop it with ``thread.httpd.shutdown()``.

    Args:
        root_directory: Directory to serve as the hub root.

    Returns:
        A ``(port, thread)`` tuple. The thread is a daemon thread running
        ``serve_forever()``.
    """
    CustomHTTPRequestHandler.hub_root = root_directory
    httpd = socketserver.TCPServer(("", 0), CustomHTTPRequestHandler)
    port = httpd.server_address[1]

    def _serve():
        # Close the listening socket once serve_forever() returns.
        with httpd:
            httpd.serve_forever()

    server_thread = threading.Thread(target=_serve, daemon=True)
    server_thread.httpd = httpd
    server_thread.start()
    return port, server_thread
def before_runs(env):
    """Runs before all performance tests: start a local model hub if possible.

    The base directory is taken from ``MOZ_ML_LOCAL_DIR``. If that is unset or
    empty, ``MOZ_FETCHES_DIR`` (the alternate location used in CI) is used
    instead. When ``<base>/onnx-models`` exists, it is served as the local
    hub. The server thread is recorded in ``THREADS`` and the hub URL is
    printed. Otherwise nothing is started.

    Args:
        env: Unused by this hook.
    """
    # `or` rather than an `is None` check: a variable that is set but empty
    # would otherwise resolve "onnx-models" against the current directory.
    fetches_dir = os.environ.get("MOZ_ML_LOCAL_DIR") or os.environ.get(
        "MOZ_FETCHES_DIR"
    )
    if not fetches_dir:
        return
    hub_dir = Path(fetches_dir) / "onnx-models"
    if not hub_dir.is_dir():
        return
    port, server_thread = start_hub(hub_dir)
    print(f"Serving local hub {hub_dir} at http://localhost:{port}")
    THREADS.append(server_thread)
def after_runs(env):
    """Runs after all performance tests: stop the local hub threads and forget them.

    Every thread in ``THREADS`` is visited, not just the first. If a live
    thread carries its server as an ``httpd`` attribute, that server is shut
    down and the thread is joined for up to 5 seconds. Any other thread is
    joined with a zero timeout, i.e. not waited on. ``THREADS`` is emptied
    afterwards.

    Args:
        env: Unused by this hook.
    """
    if not THREADS:
        return
    print("Shutting down")
    for thread in THREADS:
        server = getattr(thread, "httpd", None)
        if server is not None and thread.is_alive():
            # shutdown() returns once serve_forever() leaves its loop, which
            # happens after any in-flight request finishes. It must be called
            # from a thread other than the one running serve_forever().
            server.shutdown()
            thread.join(timeout=5)
        else:
            thread.join(timeout=0)
    THREADS.clear()