Source code

Revision control

Copy as Markdown

Other Tools

from __future__ import annotations
import logging
import threading
from copy import copy
from itertools import product
from typing import TYPE_CHECKING
from dnslib import QTYPE, RCODE, RR, A
from dnslib.label import DNSLabel
from dnslib.server import BaseResolver, DNSLogger, DNSServer
if TYPE_CHECKING:
from typing import Any, Callable
from dnslib.dns import DNSRecord
from dnslib.server import DNSHandler
from wptserve.config import Config
logger = logging.getLogger()
class Resolver(BaseResolver): # type: ignore[misc]
def __init__(
self,
allowed_hosts: set[str],
destination: str,
unknown_labels: set[str] | None = None,
ttl: int = 3600,
) -> None:
super().__init__()
self.unknown_labels = unknown_labels
self.zone = [
(rr.rname, QTYPE[rr.rtype], rr)
for rr in (
RR(
host if host.endswith(".") else host + ".",
rtype=QTYPE.A,
rdata=A(destination),
ttl=ttl,
)
for host in allowed_hosts
)
]
def resolve(self, request: DNSRecord, handler: DNSHandler) -> DNSRecord:
reply = request.reply()
qname = request.q.qname
qtype = QTYPE[request.q.qtype]
if self.unknown_labels:
qlabels = {DNSLabel(label) for label in str(qname).rstrip(".").split(".")}
if any(DNSLabel(label) in qlabels for label in self.unknown_labels):
reply.header.rcode = RCODE.NXDOMAIN
return reply
has_answer = False
for name, rtype, rr in self.zone:
if qname.matchGlob(name):
if qtype == rtype:
a = copy(rr)
a.rname = qname
reply.add_answer(a)
has_answer = True
if has_answer:
return reply
reply.header.rcode = RCODE.NXDOMAIN
return reply
class LoggingDNSLogger(DNSLogger): # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.logf: Callable[[str], None]
super().__init__(*args, logf=logger.debug, **kwargs)
def log_error(self, *args: Any, **kwargs: Any) -> None:
old_logf = self.logf
self.logf = logger.error
try:
super().log_error(*args, **kwargs)
finally:
self.logf = old_logf
class DNSServerDaemon:
def __init__(
self,
host: str,
port: int,
bind_address: bool,
config: Config,
wildcards: int | None = None,
**kwargs: Any,
) -> None:
if wildcards == 0:
wildcards = max(s.count(".") for s in config["all_domains_set"])
if wildcards is not None:
hosts = {config["browser_host"]} | set(config["alternate_hosts"].values())
wildcard_hosts = {".".join(("*",) * i) for i in range(1, wildcards + 1)}
resolver_hosts = hosts | {
".".join(x) + "." for x in product(wildcard_hosts, hosts)
}
resolver = Resolver(
resolver_hosts, config["server_host"], config["not_subdomains"]
)
else:
resolver = Resolver(config["domains_set"], config["server_host"])
self.server = DNSServer(
resolver,
address=host if bind_address else "",
port=port,
logger=LoggingDNSLogger(),
)
self.server_thread: threading.Thread | None = None
def start(self) -> None:
self.started = True
self.server_thread = threading.Thread(target=self.server.start)
self.server_thread.setDaemon(True) # don't hang on exit
self.server_thread.start()
def stop(self) -> None:
if self.started:
assert self.server_thread is not None
try:
self.server.shutdown()
self.server.server_close()
self.server_thread.join()
self.server_thread = None
except AttributeError:
pass
self.started = False
self.server = None