Source code

Revision control

Copy as Markdown

Other Tools

#!/usr/bin/env python
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Spawns necessary HTTP servers for testing Marionette in child
processes.
"""
import argparse
import multiprocessing
import os
import sys
from collections import defaultdict
from six import iteritems
from . import httpd
# Absolute path of the directory that contains this module; the default
# document root below is derived from it.
here = os.path.abspath(os.path.dirname(__file__))

# Names exported by ``from <module> import *``.
__all__ = [
    "default_doc_root",
    "iter_proc",
    "iter_url",
    "registered_servers",
    "servers",
    "start",
    "where_is",
]
class BlockingChannel(object):
    """Lock-guarded wrapper around one end of a ``multiprocessing.Pipe``.

    :meth:`send` transmits its positional arguments as a tuple, and
    :meth:`recv` unwraps single-element tuples. As a result, ``send(x)`` on
    one end yields ``x`` from ``recv()`` on the other end.
    """

    def __init__(self, channel):
        self.chan = channel
        # Serialises each individual send/recv on ``chan``.
        self.lock = multiprocessing.Lock()

    def call(self, func, args=()):
        """Send a ``(func, args)`` request and block until a reply arrives."""
        self.send((func, args))
        return self.recv()

    def send(self, *args):
        """Send ``args`` (as a tuple) over the channel."""
        # ``with`` only releases the lock if it was actually acquired; the
        # previous acquire-inside-try pattern could release an unheld lock.
        with self.lock:
            self.chan.send(args)

    def recv(self):
        """Block until a message arrives and return it.

        A single-element tuple is unwrapped to its only element. If the
        wait is interrupted by ``KeyboardInterrupt``, ``("stop", ())`` is
        returned instead, so callers can treat it like a stop request.
        """
        try:
            with self.lock:
                payload = self.chan.recv()
        except KeyboardInterrupt:
            return ("stop", ())
        if isinstance(payload, tuple) and len(payload) == 1:
            return payload[0]
        return payload
class ServerProxy(multiprocessing.Process, BlockingChannel):
    """Child process that owns a fixture server and answers RPC requests.

    In the child, the server is built with
    ``init_func(*init_args, **init_kwargs)`` and started. The process then
    replies ``("ok", ())`` over ``channel`` and serves ``(attr, args)``
    requests:

    * callable attributes of the server are invoked with ``args``;
    * other attributes are returned as-is.

    A ``"stop"`` request ends the loop after replying. An ``Exception``
    raised anywhere in the loop is reported back as ``("stop", exc)``.
    """

    def __init__(self, channel, init_func, *init_args, **init_kwargs):
        multiprocessing.Process.__init__(self)
        BlockingChannel.__init__(self, channel)
        self.init_func = init_func
        self.init_args = init_args
        self.init_kwargs = init_kwargs

    def run(self):
        # Bound up front so the KeyboardInterrupt handler can tell whether a
        # server exists yet (an interrupt inside init_func used to raise
        # UnboundLocalError here).
        server = None
        try:
            server = self.init_func(*self.init_args, **self.init_kwargs)
            server.start()
            self.send(("ok", ()))
            while True:
                # Requests look like ("func", ("arg", ...)) or ("prop", ()).
                sattr, fargs = self.recv()
                attr = getattr(server, sattr)
                if callable(attr):
                    rv = attr(*fargs)
                else:
                    rv = attr
                self.send(rv)
                if sattr == "stop":
                    return
        except Exception as e:
            self.send(("stop", e))
        except KeyboardInterrupt:
            if server is not None:
                server.stop()
class ServerProc(BlockingChannel):
    """Parent-side handle for a fixture server that runs in a child process.

    Requests are forwarded over a pipe to a :class:`ServerProxy` child.
    """

    def __init__(self, init_func):
        self._init_func = init_func
        self.proc = None
        parent_chan, self.child_chan = multiprocessing.Pipe()
        BlockingChannel.__init__(self, parent_chan)

    def start(self, doc_root, ssl_config, **kwargs):
        """Spawn the child process and wait until its server has started.

        In the child, ``init_func`` is called as
        ``init_func(doc_root, ssl_config, **kwargs)``.

        :raises: The exception the child reported while building or
            starting the server, or ``KeyboardInterrupt`` if this process
            was interrupted while waiting.
        """
        self.proc = ServerProxy(
            self.child_chan, self._init_func, doc_root, ssl_config, **kwargs
        )
        self.proc.daemon = True
        self.proc.start()
        res, exc = self.recv()
        if res == "stop":
            if isinstance(exc, BaseException):
                raise exc
            # recv() turns a local KeyboardInterrupt into ("stop", ()).
            # Raising the empty tuple would be a TypeError, so re-raise the
            # interrupt itself.
            raise KeyboardInterrupt()

    def get_url(self, url):
        """Ask the child server for the absolute URL of ``url``."""
        return self.call("get_url", (url,))

    @property
    def doc_root(self):
        """Document root reported by the child server."""
        return self.call("doc_root", ())

    def stop(self):
        """Ask the child server to stop, then wait for the process to exit.

        Does nothing if the child is not running. Sending the request in
        that case would block forever, because nobody is left to reply.
        """
        if not self.is_alive:
            return
        self.call("stop")
        self.proc.join()

    def kill(self):
        """Terminate the child process without a graceful stop."""
        if not self.is_alive:
            return
        self.proc.terminate()
        self.proc.join(0)

    @property
    def is_alive(self):
        """Whether the child process has been started and is still running."""
        if self.proc is not None:
            return self.proc.is_alive()
        return False
def http_server(doc_root, ssl_config, host="127.0.0.1", **kwargs):
    """Build a plain-HTTP ``httpd.FixtureServer`` serving ``doc_root``.

    ``ssl_config`` is unused here. It is accepted because every registered
    builder is called with the same positional arguments. The URL uses port
    ``0``, which is passed straight through to ``FixtureServer``
    (presumably to request an ephemeral port; confirm in httpd.py).
    Remaining keyword arguments are forwarded unchanged.
    """
    base_url = "http://{host}:0/".format(host=host)
    return httpd.FixtureServer(doc_root, url=base_url, **kwargs)
def https_server(doc_root, ssl_config, host="127.0.0.1", **kwargs):
    """Build a TLS ``httpd.FixtureServer`` serving ``doc_root``.

    ``ssl_config["key_path"]`` and ``ssl_config["cert_path"]`` are passed
    as ``ssl_key`` and ``ssl_cert``. The URL uses port ``0`` (passed
    through to ``FixtureServer``). Remaining keyword arguments are
    forwarded unchanged.
    """
    base_url = "https://{host}:0/".format(host=host)
    key_file = ssl_config["key_path"]
    cert_file = ssl_config["cert_path"]
    return httpd.FixtureServer(
        doc_root,
        url=base_url,
        ssl_key=key_file,
        ssl_cert=cert_file,
        **kwargs
    )
def start_servers(doc_root, ssl_config, **kwargs):
    """Start one child-process server for each entry in ``registered_servers``.

    :param doc_root: Directory passed to every server builder.
    :param ssl_config: Mapping passed to every builder (the HTTPS builder
        reads ``"key_path"`` and ``"cert_path"`` from it).
    :param kwargs: Extra keyword arguments forwarded to every builder.
    :returns: Mapping of schema -> ``(base_url, ServerProc)``.
    :raises: Whatever starting a server raises. Before re-raising, every
        server process created by this call is killed, so a failure part
        way through does not leak children.
    """
    servers = defaultdict()
    launched = []
    try:
        for schema, builder_fn in registered_servers:
            proc = ServerProc(builder_fn)
            # Track before starting so a half-started child is cleaned up too.
            launched.append(proc)
            proc.start(doc_root, ssl_config, **kwargs)
            servers[schema] = (proc.get_url("/"), proc)
    except BaseException:
        for proc in launched:
            proc.kill()
        raise
    return servers
def start(doc_root=None, **kwargs):
    """Launch every registered fixture server and return them.

    :param doc_root: Directory to serve. When falsy, ``default_doc_root``
        (the ``www`` directory next to this module's directory) is used.
    :param kwargs: Extra keyword arguments forwarded to each server builder
        and from there to ``httpd.FixtureServer``.
    :returns: The mapping built by ``start_servers``. The same object is
        also stored in the module-level ``servers`` global.
    """
    global servers
    root = doc_root if doc_root else default_doc_root
    ssl_config = dict(
        cert_path=httpd.default_ssl_cert,
        key_path=httpd.default_ssl_key,
    )
    servers = start_servers(root, ssl_config, **kwargs)
    return servers
def where_is(uri, on="http"):
    """Return the full URL (scheme, host and port) of a fixture resource.

    :param uri: Path of the resource on the server, e.g. ``"/test.html"``.
    :param on: Key in ``servers`` of the server to ask; defaults to
        ``"http"``.
    :raises KeyError: If no server is registered under ``on`` (for example
        because :func:`start` has not been called yet). Previously this
        surfaced as an opaque ``TypeError`` from indexing ``None``.
    """
    try:
        entry = servers[on]
    except KeyError:
        raise KeyError(
            "no fixture server registered for {!r} (available: {})".format(
                on, ", ".join(sorted(servers))
            )
        ) from None
    return entry[1].get_url(uri)
def iter_proc(servers):
    """Yield the ``ServerProc`` of each entry in ``servers``.

    :param servers: Mapping of schema -> ``(url, proc)``, as returned by
        :func:`start`.
    """
    # Plain dict iteration is enough on Python 3, which this module already
    # requires; the six.iteritems shim is not needed.
    for _url, proc in servers.values():
        yield proc
def iter_url(servers):
    """Yield the base URL of each entry in ``servers``.

    :param servers: Mapping of schema -> ``(url, proc)``, as returned by
        :func:`start`.
    """
    # Plain dict iteration is enough on Python 3, which this module already
    # requires; the six.iteritems shim is not needed.
    for url, _proc in servers.values():
        yield url
# Directory served when start() gets no explicit doc_root: "www" inside the
# parent of this module's directory.
default_doc_root = os.path.join(os.path.dirname(here), "www")

# (schema, builder) pairs. start_servers() launches one child process per
# entry and keys its result by schema.
registered_servers = [
    ("http", http_server),
    ("https", https_server),
]

# Set by start(): schema -> (base URL, ServerProc).
servers = defaultdict()
def main(args):
    """Command-line entry point: start the fixture servers and wait for them.

    :param args: Command-line arguments to parse, excluding the program
        name (typically ``sys.argv[1:]``).

    Prints each server's base URL to stderr, then blocks until every server
    process has exited. On ``KeyboardInterrupt`` all server processes are
    killed.
    """
    global servers
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r", dest="doc_root", help="Path to document root. Overrides default."
    )
    # Parse the arguments we were given. The old code called parse_args()
    # with no arguments, silently re-reading sys.argv and ignoring `args`.
    opts = parser.parse_args(args)
    servers = start(opts.doc_root)
    for url in iter_url(servers):
        print("{}: listening on {}".format(sys.argv[0], url), file=sys.stderr)
    try:
        while any(proc.is_alive for proc in iter_proc(servers)):
            for proc in iter_proc(servers):
                # Bounded join so the loop re-checks liveness and stays
                # responsive to Ctrl-C.
                proc.proc.join(1)
    except KeyboardInterrupt:
        for proc in iter_proc(servers):
            proc.kill()
if __name__ == "__main__":
    # Run the CLI with this process's arguments, minus the program name.
    cli_args = sys.argv[1:]
    main(cli_args)