Source code
Revision control
Copy as Markdown
Other Tools
#!/usr/bin/env python
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Spawns necessary HTTP servers for testing Marionette in child
processes.
"""
import argparse
import multiprocessing
import os
import sys
from collections import defaultdict
from six import iteritems
from . import httpd
# Public names exported by a star-import of this module.
__all__ = [
"default_doc_root",
"iter_proc",
"iter_url",
"registered_servers",
"servers",
"start",
"where_is",
]
# Absolute path of the directory that contains this module.
here = os.path.abspath(os.path.dirname(__file__))
class BlockingChannel(object):
    """Lock-guarded wrapper around one end of a message channel.

    ``channel`` only needs ``send(obj)`` and ``recv()`` methods; elsewhere
    in this module it is one end of a ``multiprocessing.Pipe``.  Each send
    and each receive is performed while holding ``self.lock``.
    """

    def __init__(self, channel):
        """Wrap *channel* and create the lock that guards it."""
        self.chan = channel
        self.lock = multiprocessing.Lock()

    def call(self, func, args=()):
        """Send a ``(func, args)`` request and return the next message.

        This blocks until something is received on the channel.
        """
        self.send((func, args))
        return self.recv()

    def send(self, *args):
        """Send the positional arguments over the channel as one tuple."""
        # The context manager releases the lock only if it was actually
        # acquired; an acquire/release pair inside ``try``/``finally`` would
        # release an unheld lock when the acquire itself is interrupted.
        with self.lock:
            self.chan.send(args)

    def recv(self):
        """Receive one message from the channel.

        A 1-tuple payload is unwrapped to its only element, undoing the
        tuple packing done by ``send`` for single-argument messages.  Any
        other payload is returned unchanged.

        A ``KeyboardInterrupt`` raised while waiting is converted into the
        request ``("stop", ())`` instead of propagating.
        """
        try:
            with self.lock:
                payload = self.chan.recv()
                if isinstance(payload, tuple) and len(payload) == 1:
                    return payload[0]
                return payload
        except KeyboardInterrupt:
            return ("stop", ())
class ServerProxy(multiprocessing.Process, BlockingChannel):
    """Process that builds a server and answers requests from a channel.

    The server object is created inside ``run`` by calling
    ``init_func(*init_args, **init_kwargs)``.  Requests arrive over
    ``channel`` as ``(attribute_name, args)`` pairs.
    """

    def __init__(self, channel, init_func, *init_args, **init_kwargs):
        """Store the channel and the server factory with its arguments."""
        multiprocessing.Process.__init__(self)
        BlockingChannel.__init__(self, channel)
        self.init_func = init_func
        self.init_args = init_args
        self.init_kwargs = init_kwargs

    def run(self):
        """Create and start the server, then serve requests until "stop".

        After a successful start ``("ok", ())`` is sent.  Each request names
        an attribute of the server: a callable is invoked with the supplied
        arguments, anything else is returned as a plain value.  The result
        is sent back, and the loop ends after a "stop" request is answered.

        Any ``Exception`` is reported to the peer as ``("stop", exc)``.  On
        ``KeyboardInterrupt`` the server is stopped if it was created.
        """
        # Bind the name up front so the KeyboardInterrupt handler below can
        # tell whether the factory ever returned a server.
        server = None
        try:
            server = self.init_func(*self.init_args, **self.init_kwargs)
            server.start()
            self.send(("ok", ()))

            while True:
                # ["func", ("arg", ...)]
                # ["prop", ()]
                sattr, fargs = self.recv()
                attr = getattr(server, sattr)

                # apply fargs to attr if it is a function
                if callable(attr):
                    rv = attr(*fargs)
                # otherwise attr is a property
                else:
                    rv = attr

                self.send(rv)

                if sattr == "stop":
                    return
        except Exception as e:
            self.send(("stop", e))
        except KeyboardInterrupt:
            # An interrupt during init_func leaves no server to stop.
            if server is not None:
                server.stop()
class ServerProc(BlockingChannel):
    """Parent-side handle for a server running in a ``ServerProxy`` process.

    ``init_func`` is the factory the child process calls to build its
    server.  A pipe is created here: this object talks over one end and the
    other end is handed to the child in ``start``.
    """

    def __init__(self, init_func):
        """Remember the server factory and create the pipe to the child."""
        self._init_func = init_func
        self.proc = None

        parent_chan, self.child_chan = multiprocessing.Pipe()
        BlockingChannel.__init__(self, parent_chan)

    def start(self, doc_root, ssl_config, **kwargs):
        """Spawn the child process and wait for its start-up message.

        ``doc_root``, ``ssl_config`` and ``kwargs`` are forwarded to the
        server factory.  If the child reports ``("stop", exc)`` the carried
        exception is raised here.
        """
        self.proc = ServerProxy(
            self.child_chan, self._init_func, doc_root, ssl_config, **kwargs
        )
        self.proc.daemon = True
        self.proc.start()

        res, exc = self.recv()
        if res == "stop":
            raise exc

    def get_url(self, url):
        """Return the child server's ``get_url(url)`` result."""
        return self.call("get_url", (url,))

    @property
    def doc_root(self):
        """The ``doc_root`` attribute of the server in the child process."""
        return self.call("doc_root", ())

    def stop(self):
        """Ask the child server to stop and wait for the process to exit.

        Does nothing when the process was never started or has already
        exited.
        """
        # Check liveness first: ``call`` blocks waiting for a reply, and a
        # process that is not running will never send one.
        if not self.is_alive:
            return
        self.call("stop")
        self.proc.join()

    def kill(self):
        """Terminate the child process if it is still running."""
        if not self.is_alive:
            return
        self.proc.terminate()
        self.proc.join(0)

    @property
    def is_alive(self):
        """True if the child process has been created and is alive."""
        if self.proc is not None:
            return self.proc.is_alive()
        return False
# NOTE(review): no body for http_server is visible at this point; the
# definition looks truncated. Restore it from revision control before
# relying on this module.
def http_server(doc_root, ssl_config, host="127.0.0.1", **kwargs):
def https_server(doc_root, ssl_config, host="127.0.0.1", **kwargs):
    """Build an ``httpd.FixtureServer`` configured with an SSL key and cert.

    ``ssl_config`` must provide the "key_path" and "cert_path" entries.
    Extra keyword arguments are passed through to ``FixtureServer``.

    NOTE(review): ``host`` is accepted but not used by the code shown here;
    confirm whether that is intended.
    """
    key_path = ssl_config["key_path"]
    cert_path = ssl_config["cert_path"]
    return httpd.FixtureServer(
        doc_root, ssl_key=key_path, ssl_cert=cert_path, **kwargs
    )
def start_servers(doc_root, ssl_config, **kwargs):
    """Start one server process per entry in ``registered_servers``.

    Each entry supplies a key and a builder function.  The builder is
    wrapped in a ``ServerProc``, started with the given arguments, and
    stored under its key as a ``(url_of_root, process_handle)`` pair.

    Returns the mapping of keys to those pairs.
    """
    running = defaultdict()
    for entry in registered_servers:
        scheme, factory = entry
        server_proc = ServerProc(factory)
        server_proc.start(doc_root, ssl_config, **kwargs)
        root_url = server_proc.get_url("/")
        running[scheme] = (root_url, server_proc)
    return running
def start(doc_root=None, **kwargs):
    """Start every registered test server and record them module-wide.

    When ``doc_root`` is empty or omitted, ``default_doc_root`` is used
    instead.  The SSL certificate and key paths come from the defaults
    defined in ``httpd``.  Remaining keyword arguments are forwarded to
    ``start_servers``.

    The resulting mapping replaces the module-level ``servers`` and is also
    returned.
    """
    global servers

    if not doc_root:
        doc_root = default_doc_root

    ssl_config = dict(
        cert_path=httpd.default_ssl_cert,
        key_path=httpd.default_ssl_key,
    )

    servers = start_servers(doc_root, ssl_config, **kwargs)
    return servers
def where_is(uri, on="http"):
    """Resolve ``uri`` to a full URL on one of the running servers.

    ``on`` selects the entry of the module-level ``servers`` mapping to
    ask; it defaults to "http".  The answer comes from that entry's
    process handle via ``get_url``.
    """
    entry = servers.get(on)
    proc = entry[1]
    return proc.get_url(uri)
def iter_proc(servers):
    """Yield the process handle from every entry of *servers*.

    ``servers`` maps a key to a ``(url, proc)`` pair.  Only the values are
    needed, so they are iterated directly.
    """
    for _url, proc in servers.values():
        yield proc
def iter_url(servers):
    """Yield the URL from every entry of *servers*.

    ``servers`` maps a key to a ``(url, proc)`` pair.  Only the values are
    needed, so they are iterated directly.
    """
    for url, _proc in servers.values():
        yield url
# Default document root: the "www" directory inside the parent of the
# directory that contains this module.
default_doc_root = os.path.join(os.path.dirname(here), "www")
# (key, builder function) pairs; start_servers() starts one process per pair.
registered_servers = [("http", http_server), ("https", https_server)]
# Mapping of key -> (url, process handle); replaced by start() and read by
# where_is(). No default factory is given, so missing keys are not created.
servers = defaultdict()
def main(args):
    """Command-line entry point: start the servers and wait on them.

    ``args`` is the list of command-line arguments to parse (without the
    program name).  The only option is ``-r``, which overrides the default
    document root.

    Each server URL is printed to stderr.  The function then blocks while
    any server process is alive; on ``KeyboardInterrupt`` every server
    process is killed.
    """
    global servers

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r", dest="doc_root", help="Path to document root. Overrides default."
    )
    # Parse the arguments handed in by the caller rather than implicitly
    # re-reading sys.argv, so that main() honours its parameter.
    options = parser.parse_args(args)

    servers = start(options.doc_root)
    for url in iter_url(servers):
        print("{}: listening on {}".format(sys.argv[0], url), file=sys.stderr)

    try:
        while any(proc.is_alive for proc in iter_proc(servers)):
            for proc in iter_proc(servers):
                # Join with a timeout so an interrupt is noticed promptly.
                proc.proc.join(1)
    except KeyboardInterrupt:
        for proc in iter_proc(servers):
            proc.kill()
# Script entry point: forward the command-line arguments, minus the program
# name, to main().
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)