Source code

Revision control

Copy as Markdown

Other Tools

/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim:set ts=4 sw=2 sts=2 et cin: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "DNS.h"
#include "DNSUtils.h"
#include "nsCharSeparatedTokenizer.h"
#include "nsContentUtils.h"
#include "nsHttpHandler.h"
#include "nsHttpChannel.h"
#include "nsHostResolver.h"
#include "nsIHttpChannel.h"
#include "nsIHttpChannelInternal.h"
#include "nsIIOService.h"
#include "nsIInputStream.h"
#include "nsIObliviousHttp.h"
#include "nsISupports.h"
#include "nsISupportsUtils.h"
#include "nsITimedChannel.h"
#include "nsIUploadChannel2.h"
#include "nsIURIMutator.h"
#include "nsNetUtil.h"
#include "nsQueryObject.h"
#include "nsStringStream.h"
#include "nsThreadUtils.h"
#include "nsURLHelper.h"
#include "ObliviousHttpChannel.h"
#include "TRR.h"
#include "TRRService.h"
#include "TRRServiceChannel.h"
#include "TRRLoadInfo.h"
#include "mozilla/Base64.h"
#include "mozilla/DebugOnly.h"
#include "mozilla/Logging.h"
#include "mozilla/Preferences.h"
#include "mozilla/StaticPrefs_network.h"
#include "mozilla/Telemetry.h"
#include "mozilla/TimeStamp.h"
#include "mozilla/Tokenizer.h"
#include "mozilla/UniquePtr.h"
// Put DNSLogging.h at the end to avoid LOG being overwritten by other headers.
#include "DNSLogging.h"
#include "mozilla/glean/GleanMetrics.h"
namespace mozilla {
namespace net {
// XPCOM interface list for TRR: the interfaces named in this macro invocation
// are the ones the class implements callbacks for further down in this file
// (push listener, interface requestor, stream listener, runnable, timer
// callback).
NS_IMPL_ISUPPORTS(TRR, nsIHttpPushListener, nsIInterfaceRequestor,
nsIStreamListener, nsIRunnable, nsITimerCallback)
// Constructor used when firing off a normal A or AAAA query for |aRec|.
// The host name, private-browsing flag and origin suffix are all taken from
// the record, so |aRec| must not be null here.
TRR::TRR(AHostResolver* aResolver, nsHostRecord* aRec, enum TrrType aType)
    : mozilla::Runnable("TRR"),
      mRec(aRec),
      mHostResolver(aResolver),
      mType(aType),
      mOriginSuffix(aRec->originSuffix) {
  MOZ_DIAGNOSTIC_ASSERT(XRE_IsParentProcess() || XRE_IsSocketProcess(),
                        "TRR must be in parent or socket process");
  mPB = aRec->pb;
  mHost = aRec->host;
}
// Constructor used when following CNAMEs: |aHost| is the name to resolve next
// and |aLoopCount| is the remaining CNAME budget stored in mCnameLoop.
// |aRec| may be null, in which case the origin suffix is left empty.
TRR::TRR(AHostResolver* aResolver, nsHostRecord* aRec, nsCString& aHost,
         enum TrrType& aType, unsigned int aLoopCount, bool aPB)
    : mozilla::Runnable("TRR"),
      mHost(aHost),
      mRec(aRec),
      mHostResolver(aResolver),
      mType(aType),
      mPB(aPB),
      mCnameLoop(aLoopCount),
      mOriginSuffix(aRec ? aRec->originSuffix : ""_ns) {
  MOZ_DIAGNOSTIC_ASSERT(XRE_IsParentProcess() || XRE_IsSocketProcess(),
                        "TRR must be in parent or socket process");
}
// Constructor used for pushed responses: only the resolver and the
// private-browsing flag are known up front; host and type are filled in
// later by ReceivePush().
TRR::TRR(AHostResolver* aResolver, bool aPB)
    : mozilla::Runnable("TRR"),
      mHostResolver(aResolver),
      mPB(aPB) {
  MOZ_DIAGNOSTIC_ASSERT(XRE_IsParentProcess() || XRE_IsSocketProcess(),
                        "TRR must be in parent or socket process");
}
// Constructor used to verify a domain. No host record is attached (mRec is
// null), so every piece of request state is passed in explicitly.
// |aUseFreshConnection| is consulted when the HTTP request is built.
TRR::TRR(AHostResolver* aResolver, nsACString& aHost, enum TrrType aType,
         const nsACString& aOriginSuffix, bool aPB, bool aUseFreshConnection)
    : mozilla::Runnable("TRR"),
      mHost(aHost),
      mRec(nullptr),
      mHostResolver(aResolver),
      mType(aType),
      mPB(aPB),
      mOriginSuffix(aOriginSuffix),
      mUseFreshConnection(aUseFreshConnection) {
  MOZ_DIAGNOSTIC_ASSERT(XRE_IsParentProcess() || XRE_IsSocketProcess(),
                        "TRR must be in parent or socket process");
}
// Called from Notify() once the request timer has fired: records the timeout
// as the skip reason and cancels the request with
// NS_ERROR_NET_TIMEOUT_EXTERNAL.
void TRR::HandleTimeout() {
  // The timer has fired, so release our reference to it before cancelling.
  mTimeout = nullptr;

  RecordReason(TRRSkippedReason::TRR_TIMEOUT);
  Cancel(NS_ERROR_NET_TIMEOUT_EXTERNAL);
}
// nsITimerCallback: the only timer this object arms is mTimeout, so any other
// timer reaching this callback is treated as a fatal programming error.
NS_IMETHODIMP
TRR::Notify(nsITimer* aTimer) {
  if (aTimer != mTimeout) {
    MOZ_CRASH("Unknown timer");
  }

  HandleTimeout();
  return NS_OK;
}
// nsIRunnable entry point: sends the HTTP request. If the TRR service is gone
// or the request could not be sent, the lookup is completed with a failure.
// Always returns NS_OK; errors are reported through FailData().
NS_IMETHODIMP
TRR::Run() {
  MOZ_ASSERT_IF(XRE_IsParentProcess() && TRRService::Get(),
                NS_IsMainThread() || TRRService::Get()->IsOnTRRThread());
  MOZ_ASSERT_IF(XRE_IsSocketProcess(), NS_IsMainThread());

  // SendHTTPRequest() is only attempted while the service still exists.
  const bool sent =
      (TRRService::Get() != nullptr) && NS_SUCCEEDED(SendHTTPRequest());
  if (!sent) {
    RecordReason(TRRSkippedReason::TRR_SEND_FAILED);
    FailData(NS_ERROR_FAILURE);
    // The dtor will now be run
  }
  return NS_OK;
}
// Returns the DNSPacket owned by this request, allocating it on first use.
// The returned pointer is owned by mPacket.
DNSPacket* TRR::GetOrCreateDNSPacket() {
  if (mPacket) {
    return mPacket.get();
  }

  mPacket = MakeUnique<DNSPacket>();
  return mPacket.get();
}
// Builds the URI the DNS query is sent to.
//
// The spec comes from the TRR service when the default server is in use, and
// from the host record's mTrrServer otherwise. On a parse failure
// TRR_BAD_URL is recorded and the NS_NewURI error is returned; on success the
// new URI is handed to |aOutURI|.
nsresult TRR::CreateQueryURI(nsIURI** aOutURI) {
  nsAutoCString spec;
  if (!UseDefaultServer()) {
    // UseDefaultServer() is only false when mRec is non-null.
    spec = mRec->mTrrServer;
  } else {
    TRRService::Get()->GetURI(spec);
  }

  nsCOMPtr<nsIURI> parsed;
  const nsresult rv = NS_NewURI(getter_AddRefs(parsed), spec);
  if (NS_FAILED(rv)) {
    RecordReason(TRRSkippedReason::TRR_BAD_URL);
    return rv;
  }

  parsed.forget(aOutURI);
  return NS_OK;
}
// Decides whether this request must NOT be sent over TRR.
//
// Only A and AAAA queries that are not in TRR-only mode are subject to the
// checks below; every other request is let through (returns false). When the
// request is blocked, the matching TRRSkippedReason is recorded and true is
// returned.
bool TRR::MaybeBlockRequest() {
  // let NS resolves skip the blocklist check
  if ((mType != TRRTYPE_A) && (mType != TRRTYPE_AAAA)) {
    return false;
  }

  // A/AAAA queries always carry a host record. Assert *before* the first
  // dereference so that a missing record is caught by the assertion rather
  // than by a null dereference.
  MOZ_ASSERT(mRec);

  // we also don't check the blocklist for TRR only requests
  if (mRec->mEffectiveTRRMode == nsIRequest::TRR_ONLY_MODE) {
    return false;
  }

  // If TRRService isn't enabled anymore for the req, don't do TRR.
  if (!TRRService::Get()->Enabled(mRec->mEffectiveTRRMode)) {
    RecordReason(TRRSkippedReason::TRR_MODE_NOT_ENABLED);
    return true;
  }

  if (!StaticPrefs::network_trr_strict_native_fallback() &&
      UseDefaultServer() &&
      TRRService::Get()->IsTemporarilyBlocked(mHost, mOriginSuffix, mPB,
                                              true)) {
    if (mType == TRRTYPE_A) {
      // count only blocklist for A records to avoid double counts
      Telemetry::Accumulate(Telemetry::DNS_TRR_BLACKLISTED3,
                            TRRService::ProviderKey(), true);
    }

    RecordReason(TRRSkippedReason::TRR_HOST_BLOCKED_TEMPORARY);
    // not really an error but no TRR is issued
    return true;
  }

  if (TRRService::Get()->IsExcludedFromTRR(mHost)) {
    RecordReason(TRRSkippedReason::TRR_EXCLUDED);
    return true;
  }

  if (UseDefaultServer() && (mType == TRRTYPE_A)) {
    Telemetry::Accumulate(Telemetry::DNS_TRR_BLACKLISTED3,
                          TRRService::ProviderKey(), false);
  }

  return false;
}
// Encodes the DNS question for mHost/mType, builds the HTTP (or Oblivious
// HTTP) channel that carries it, opens the channel with |this| as listener
// and arms the request timeout.
//
// Returns NS_OK once the channel has been opened. Any failure is returned to
// the caller and no channel is stored in mChannel.
nsresult TRR::SendHTTPRequest() {
  // This is essentially the "run" method - created from nsHostResolver
  if (mCancelled) {
    return NS_ERROR_FAILURE;
  }

  if ((mType != TRRTYPE_A) && (mType != TRRTYPE_AAAA) &&
      (mType != TRRTYPE_NS) && (mType != TRRTYPE_TXT) &&
      (mType != TRRTYPE_HTTPSSVC)) {
    // limit the calling interface because nsHostResolver has explicit slots for
    // these types
    return NS_ERROR_FAILURE;
  }

  if (MaybeBlockRequest()) {
    return NS_ERROR_UNKNOWN_HOST;
  }

  LOG(("TRR::SendHTTPRequest resolve %s type %u\n", mHost.get(), mType));

  nsAutoCString body;
  bool disableECS = StaticPrefs::network_trr_disable_ECS();
  nsresult rv =
      GetOrCreateDNSPacket()->EncodeRequest(body, mHost, mType, disableECS);
  if (NS_FAILED(rv)) {
    HandleEncodeError(rv);
    return rv;
  }

  bool useGet = StaticPrefs::network_trr_useGET();
  nsCOMPtr<nsIURI> dnsURI;
  rv = CreateQueryURI(getter_AddRefs(dnsURI));
  if (NS_FAILED(rv)) {
    LOG(("TRR:SendHTTPRequest: NewURI failed!\n"));
    return rv;
  }

  if (useGet) {
    /* For GET requests, the outgoing packet needs to be Base64url-encoded and
       then appended to the end of the URI. */
    nsAutoCString encoded;
    rv = Base64URLEncode(body.Length(),
                         reinterpret_cast<const unsigned char*>(body.get()),
                         Base64URLEncodePaddingPolicy::Omit, encoded);
    NS_ENSURE_SUCCESS(rv, rv);

    nsAutoCString query;
    rv = dnsURI->GetQuery(query);
    if (NS_FAILED(rv)) {
      return rv;
    }

    if (query.IsEmpty()) {
      query.Assign("?dns="_ns);
    } else {
      query.Append("&dns="_ns);
    }
    query.Append(encoded);

    rv = NS_MutateURI(dnsURI).SetQuery(query).Finalize(dnsURI);
    // Without the dns= parameter the GET would carry no question at all, so
    // a failed mutation must abort the request instead of being ignored.
    NS_ENSURE_SUCCESS(rv, rv);
    // Log the text that was actually appended to the URI; |body| is the raw
    // binary packet and is not meaningful when printed as a C string.
    LOG(("TRR::SendHTTPRequest GET dns=%s\n", encoded.get()));
  }

  nsCOMPtr<nsIChannel> channel;
  bool useOHTTP = StaticPrefs::network_trr_use_ohttp();
  if (useOHTTP) {
    nsCOMPtr<nsIObliviousHttpService> ohttpService(
        do_GetService("@mozilla.org/network/oblivious-http-service;1"));
    if (!ohttpService) {
      return NS_ERROR_FAILURE;
    }
    nsCOMPtr<nsIURI> relayURI;
    nsTArray<uint8_t> encodedConfig;
    rv = ohttpService->GetTRRSettings(getter_AddRefs(relayURI), encodedConfig);
    if (NS_FAILED(rv)) {
      return rv;
    }
    if (!relayURI) {
      return NS_ERROR_FAILURE;
    }
    rv = ohttpService->NewChannel(relayURI, dnsURI, encodedConfig,
                                  getter_AddRefs(channel));
  } else {
    rv = DNSUtils::CreateChannelHelper(dnsURI, getter_AddRefs(channel));
  }
  if (NS_FAILED(rv) || !channel) {
    LOG(("TRR:SendHTTPRequest: NewChannel failed!\n"));
    // A null channel with a success code must still be reported as a
    // failure, otherwise the caller believes a request is in flight.
    return NS_FAILED(rv) ? rv : NS_ERROR_FAILURE;
  }

  auto loadFlags = nsIRequest::LOAD_ANONYMOUS | nsIRequest::INHIBIT_CACHING |
                   nsIRequest::LOAD_BYPASS_CACHE |
                   nsIChannel::LOAD_BYPASS_URL_CLASSIFIER;
  if (mUseFreshConnection) {
    // Causes TRRServiceChannel to tell the connection manager
    // to clear out any connection with the current conn info.
    loadFlags |= nsIRequest::LOAD_FRESH_CONNECTION;
  }
  // Capture the result so the check below tests this call, not a stale rv.
  rv = channel->SetLoadFlags(loadFlags);
  NS_ENSURE_SUCCESS(rv, rv);

  rv = channel->SetNotificationCallbacks(this);
  NS_ENSURE_SUCCESS(rv, rv);

  nsCOMPtr<nsIHttpChannel> httpChannel = do_QueryInterface(channel);
  if (!httpChannel) {
    return NS_ERROR_UNEXPECTED;
  }

  // This connection should not use TRR
  rv = httpChannel->SetTRRMode(nsIRequest::TRR_DISABLED_MODE);
  NS_ENSURE_SUCCESS(rv, rv);

  nsCString contentType(ContentType());
  rv = httpChannel->SetRequestHeader("Accept"_ns, contentType, false);
  NS_ENSURE_SUCCESS(rv, rv);

  nsAutoCString cred;
  if (UseDefaultServer()) {
    TRRService::Get()->GetCredentials(cred);
  }
  if (!cred.IsEmpty()) {
    rv = httpChannel->SetRequestHeader("Authorization"_ns, cred, false);
    NS_ENSURE_SUCCESS(rv, rv);
  }

  nsCOMPtr<nsIHttpChannelInternal> internalChannel = do_QueryInterface(channel);
  if (!internalChannel) {
    return NS_ERROR_UNEXPECTED;
  }

  // setting a small stream window means the h2 stack won't pipeline a window
  // update with each HEADERS or reply to a DATA with a WINDOW UPDATE
  rv = internalChannel->SetInitialRwin(127 * 1024);
  NS_ENSURE_SUCCESS(rv, rv);
  rv = internalChannel->SetIsTRRServiceChannel(true);
  NS_ENSURE_SUCCESS(rv, rv);

  if (UseDefaultServer() && StaticPrefs::network_trr_async_connInfo()) {
    RefPtr<nsHttpConnectionInfo> trrConnInfo =
        TRRService::Get()->TRRConnectionInfo();
    if (trrConnInfo) {
      nsAutoCString host;
      dnsURI->GetHost(host);
      if (host.Equals(trrConnInfo->GetOrigin())) {
        internalChannel->SetConnectionInfo(trrConnInfo);
        LOG(("TRR::SendHTTPRequest use conn info:%s\n",
             trrConnInfo->HashKey().get()));
      } else {
        MOZ_DIAGNOSTIC_ASSERT(false);
      }
    } else {
      TRRService::Get()->InitTRRConnectionInfo();
    }
  }

  if (useGet) {
    rv = httpChannel->SetRequestMethod("GET"_ns);
    NS_ENSURE_SUCCESS(rv, rv);
  } else {
    nsCOMPtr<nsIUploadChannel2> uploadChannel = do_QueryInterface(httpChannel);
    if (!uploadChannel) {
      return NS_ERROR_UNEXPECTED;
    }
    // Read the length before |body| is moved into the stream.
    uint32_t streamLength = body.Length();
    nsCOMPtr<nsIInputStream> uploadStream;
    rv =
        NS_NewCStringInputStream(getter_AddRefs(uploadStream), std::move(body));
    NS_ENSURE_SUCCESS(rv, rv);

    rv = uploadChannel->ExplicitSetUploadStream(uploadStream, contentType,
                                                streamLength, "POST"_ns, false);
    NS_ENSURE_SUCCESS(rv, rv);
  }

  rv = SetupTRRServiceChannelInternal(httpChannel, useGet, contentType);
  if (NS_FAILED(rv)) {
    return rv;
  }

  rv = httpChannel->AsyncOpen(this);
  if (NS_FAILED(rv)) {
    return rv;
  }

  // If the asyncOpen succeeded we can say that we actually attempted to
  // use the TRR connection.
  RefPtr<AddrHostRecord> addrRec = do_QueryObject(mRec);
  if (addrRec) {
    addrRec->mResolverType = ResolverType();
  }

  // A per-request timeout (mTimeoutMs) takes precedence over the service-wide
  // default.
  NS_NewTimerWithCallback(
      getter_AddRefs(mTimeout), this,
      mTimeoutMs ? mTimeoutMs : TRRService::Get()->GetRequestTimeout(),
      nsITimer::TYPE_ONE_SHOT);

  mChannel = channel;
  return NS_OK;
}
// static
// Applies the request headers and channel settings shared by TRR requests.
//
// aChannel     - channel to configure; must not be null.
// aUseGet      - when false, a "Cache-Control: no-store" request header is
//                added.
// aContentType - set as the channel's default response content type.
//
// Returns the first header-setting failure, NS_OK otherwise. Failure to set
// the content type is only logged.
nsresult TRR::SetupTRRServiceChannelInternal(nsIHttpChannel* aChannel,
                                             bool aUseGet,
                                             const nsACString& aContentType) {
  nsCOMPtr<nsIHttpChannel> chan = aChannel;
  MOZ_ASSERT(chan);

  nsresult rv = NS_OK;
  if (!aUseGet) {
    rv = chan->SetRequestHeader("Cache-Control"_ns, "no-store"_ns, false);
    NS_ENSURE_SUCCESS(rv, rv);
  }

  // Sanitize the request by removing the Accept-Language header so we minimize
  // the amount of fingerprintable information we send to the server.
  const bool sendAcceptLanguage =
      StaticPrefs::network_trr_send_accept_language_headers();
  if (!sendAcceptLanguage) {
    rv = chan->SetRequestHeader("Accept-Language"_ns, ""_ns, false);
    NS_ENSURE_SUCCESS(rv, rv);
  }

  // Sanitize the request by removing the User-Agent
  const bool sendUserAgent = StaticPrefs::network_trr_send_user_agent_headers();
  if (!sendUserAgent) {
    rv = chan->SetRequestHeader("User-Agent"_ns, ""_ns, false);
    NS_ENSURE_SUCCESS(rv, rv);
  }

  if (StaticPrefs::network_trr_send_empty_accept_encoding_headers()) {
    rv = chan->SetEmptyRequestHeader("Accept-Encoding"_ns);
    NS_ENSURE_SUCCESS(rv, rv);
  }

  // set the *default* response content type
  rv = chan->SetContentType(aContentType);
  if (NS_FAILED(rv)) {
    LOG(("TRR::SetupTRRServiceChannelInternal: couldn't set content-type!\n"));
  }

  // Timing is enabled when the channel exposes nsITimedChannel.
  if (nsCOMPtr<nsITimedChannel> timed = do_QueryInterface(chan)) {
    timed->SetTimingEnabled(true);
  }

  return NS_OK;
}
// nsIInterfaceRequestor: the only interface handed out is
// nsIHttpPushListener; every other IID yields NS_ERROR_NO_INTERFACE.
NS_IMETHODIMP
TRR::GetInterface(const nsIID& iid, void** result) {
  if (iid.Equals(NS_GET_IID(nsIHttpPushListener))) {
    // Hand the caller its own strong reference to this object.
    nsCOMPtr<nsIHttpPushListener> self(this);
    *result = self.forget().take();
    return NS_OK;
  }

  return NS_ERROR_NO_INTERFACE;
}
// Extracts the queried host name and record type from the "dns=" parameter
// of a URI query string.
//
// query - the '&'-separated query string to search.
// host  - out: dot-separated labels of the first question; it is truncated
//         before decoding and may hold a partial name on failure.
// type  - out: the 16-bit value that follows the name, cast to TrrType.
//
// Returns NS_ERROR_ILLEGAL_VALUE when no dns= parameter exists, the
// Base64URLDecode error when decoding fails, NS_ERROR_FAILURE for a malformed
// header and NS_ERROR_UNEXPECTED when the packet is shorter than its labels
// claim.
nsresult TRR::DohDecodeQuery(const nsCString& query, nsCString& host,
                             enum TrrType& type) {
  LOG(("TRR::DohDecodeQuery %s!\n", query.get()));

  // extract "dns=" from the query string
  nsAutoCString encoded;
  bool sawDnsParam = false;
  for (const nsACString& param :
       nsCCharSeparatedTokenizer(query, '&').ToRange()) {
    nsDependentCSubstring head = Substring(param, 0, 4);
    nsAutoCString prefix(head);
    if (!prefix.Equals("dns=")) {
      continue;
    }
    nsDependentCSubstring value = Substring(param, 4, -1);
    encoded = value;
    sawDnsParam = true;
    break;
  }
  if (!sawDnsParam) {
    LOG(("TRR::DohDecodeQuery no dns= in pushed URI query string\n"));
    return NS_ERROR_ILLEGAL_VALUE;
  }

  FallibleTArray<uint8_t> packet;
  nsresult rv =
      Base64URLDecode(encoded, Base64URLDecodePaddingPolicy::Ignore, packet);
  NS_ENSURE_SUCCESS(rv, rv);

  // The fixed header read below occupies the first 12 bytes.
  const uint32_t avail = packet.Length();
  if (avail < 12) {
    return NS_ERROR_FAILURE;
  }

  // check the query bit and the opcode
  if ((packet[2] & 0xf8) != 0) {
    return NS_ERROR_FAILURE;
  }

  const uint32_t qdcount = (packet[4] << 8) + packet[5];
  if (!qdcount) {
    return NS_ERROR_FAILURE;
  }

  // Walk the length-prefixed labels that start right after the header; a
  // zero length byte ends the name.
  uint32_t offset = 12;
  host.Truncate();
  for (;;) {
    if (avail < (offset + 1)) {
      return NS_ERROR_UNEXPECTED;
    }

    const uint32_t labelLen = packet[offset];
    if (labelLen) {
      if (host.Length()) {
        host.Append(".");
      }
      if (avail < (offset + 1 + labelLen)) {
        return NS_ERROR_UNEXPECTED;
      }
      host.Append(reinterpret_cast<const char*>(&packet[0]) + offset + 1,
                  labelLen);
    }

    offset += 1 + labelLen;  // skip length byte + label
    if (!labelLen) {
      break;
    }
  }

  LOG(("TRR::DohDecodeQuery host %s\n", host.get()));

  if (avail < (offset + 2)) {
    return NS_ERROR_UNEXPECTED;
  }

  // Big-endian 16-bit record type follows the name.
  const uint16_t qtype =
      static_cast<uint16_t>((packet[offset] << 8) + packet[offset + 1]);
  type = static_cast<enum TrrType>(qtype);

  LOG(("TRR::DohDecodeQuery type %d\n", (int)type));

  return NS_OK;
}
// Adopts a pushed HTTP channel as the response for this (push-constructed)
// TRR object.
//
// The host and type are decoded from the pushed URI's query string; IP
// literals, unsupported record types and hosts excluded from TRR are
// rejected. On success a host record is obtained from the resolver,
// registered through TrrLookup_unlocked(), the pushed channel is opened with
// |this| as listener, and mChannel/mRec are set.
//
// pushed    - the pushed channel.
// pushedRec - record of the request the push arrived on; its flags, address
//             family, pb flag, origin suffix and TRR mode are reused.
nsresult TRR::ReceivePush(nsIHttpChannel* pushed, nsHostRecord* pushedRec) {
  if (!mHostResolver) {
    return NS_ERROR_UNEXPECTED;
  }

  LOG(("TRR::ReceivePush: PUSH incoming!\n"));

  nsAutoCString query;
  nsCOMPtr<nsIURI> pushedURI;
  pushed->GetURI(getter_AddRefs(pushedURI));
  if (pushedURI) {
    pushedURI->GetQuery(query);
  }

  const bool decoded = NS_SUCCEEDED(DohDecodeQuery(query, mHost, mType));
  if (!decoded || HostIsIPLiteral(mHost)) {  // literal
    LOG(("TRR::ReceivePush failed to decode %s\n", mHost.get()));
    return NS_ERROR_UNEXPECTED;
  }

  const bool supportedType = (mType == TRRTYPE_A) || (mType == TRRTYPE_AAAA) ||
                             (mType == TRRTYPE_TXT) ||
                             (mType == TRRTYPE_HTTPSSVC);
  if (!supportedType) {
    LOG(("TRR::ReceivePush unknown type %d\n", mType));
    return NS_ERROR_UNEXPECTED;
  }

  // NOTE(review): TRRService::Get() is dereferenced without a null check
  // here, unlike in Run() and FollowCname() - confirm it cannot be null.
  if (TRRService::Get()->IsExcludedFromTRR(mHost)) {
    return NS_ERROR_FAILURE;
  }

  // Map the decoded record type onto the resolver's lookup type.
  uint32_t resolveType = nsIDNSService::RESOLVE_TYPE_DEFAULT;
  switch (mType) {
    case TRRTYPE_TXT:
      resolveType = nsIDNSService::RESOLVE_TYPE_TXT;
      break;
    case TRRTYPE_HTTPSSVC:
      resolveType = nsIDNSService::RESOLVE_TYPE_HTTPSSVC;
      break;
    default:
      break;
  }

  RefPtr<nsHostRecord> newRecord;
  nsresult rv = mHostResolver->GetHostRecord(
      mHost, ""_ns, resolveType, pushedRec->flags, pushedRec->af,
      pushedRec->pb, pushedRec->originSuffix, getter_AddRefs(newRecord));
  if (NS_FAILED(rv)) {
    return rv;
  }

  // Since we don't ever call nsHostResolver::NameLookup for this record,
  // we need to copy the trr mode from the previous record
  if (newRecord->mEffectiveTRRMode == nsIRequest::TRR_DEFAULT_MODE) {
    newRecord->mEffectiveTRRMode =
        static_cast<nsIRequest::TRRMode>(pushedRec->mEffectiveTRRMode);
  }

  rv = mHostResolver->TrrLookup_unlocked(newRecord, this);
  if (NS_FAILED(rv)) {
    return rv;
  }

  rv = pushed->AsyncOpen(this);
  if (NS_FAILED(rv)) {
    return rv;
  }

  // OK!
  mChannel = pushed;
  mRec.swap(newRecord);

  return NS_OK;
}
// nsIHttpPushListener: a pushed stream arrived on our channel. Pushes are
// only accepted while a host record is attached and the default server is in
// use; a fresh TRR object is then created to receive the pushed response.
NS_IMETHODIMP
TRR::OnPush(nsIHttpChannel* associated, nsIHttpChannel* pushed) {
  LOG(("TRR::OnPush entry\n"));
  MOZ_ASSERT(associated == mChannel);

  if (!mRec || !UseDefaultServer()) {
    return NS_ERROR_FAILURE;
  }

  RefPtr<TRR> pushTrr = new TRR(mHostResolver, mPB);
  pushTrr->SetPurpose(mPurpose);
  return pushTrr->ReceivePush(pushed, mRec);
}
// nsIRequestObserver: inspects the request status at start and, when it is a
// failure, records a TRRSkippedReason describing it. Always returns NS_OK.
NS_IMETHODIMP
TRR::OnStartRequest(nsIRequest* aRequest) {
  LOG(("TRR::OnStartRequest %p %s %d\n", this, mHost.get(), mType));

  nsresult status = NS_OK;
  aRequest->GetStatus(&status);
  if (NS_SUCCEEDED(status)) {
    return NS_OK;
  }

  // The offline check runs first, ahead of the per-status mapping below.
  if (NS_IsOffline()) {
    RecordReason(TRRSkippedReason::TRR_IS_OFFLINE);
  }

  switch (status) {
    case NS_ERROR_UNKNOWN_HOST:
      RecordReason(TRRSkippedReason::TRR_CHANNEL_DNS_FAIL);
      break;
    case NS_ERROR_OFFLINE:
      RecordReason(TRRSkippedReason::TRR_IS_OFFLINE);
      break;
    case NS_ERROR_NET_RESET:
      RecordReason(TRRSkippedReason::TRR_NET_RESET);
      break;
    case NS_ERROR_NET_TIMEOUT:
    case NS_ERROR_NET_TIMEOUT_EXTERNAL:
      RecordReason(TRRSkippedReason::TRR_NET_TIMEOUT);
      break;
    case NS_ERROR_PROXY_CONNECTION_REFUSED:
      RecordReason(TRRSkippedReason::TRR_NET_REFUSED);
      break;
    case NS_ERROR_NET_INTERRUPT:
      RecordReason(TRRSkippedReason::TRR_NET_INTERRUPT);
      break;
    case NS_ERROR_NET_INADEQUATE_SECURITY:
      RecordReason(TRRSkippedReason::TRR_NET_INADEQ_SEQURITY);
      break;
    default:
      RecordReason(TRRSkippedReason::TRR_UNKNOWN_CHANNEL_FAILURE);
      break;
  }

  return NS_OK;
}
void TRR::SaveAdditionalRecords(
const nsClassHashtable<nsCStringHashKey, DOHresp>& aRecords) {
if (!mRec) {
return;
}
nsresult rv;
for (const auto& recordEntry : aRecords) {
if (!recordEntry.GetData() || recordEntry.GetData()->mAddresses.IsEmpty()) {
// no point in adding empty records.
continue;
}
// If IPv6 is disabled don't add anything else than IPv4.
if (StaticPrefs::network_dns_disableIPv6() &&
std::find_if(recordEntry.GetData()->mAddresses.begin(),
recordEntry.GetData()->mAddresses.end(),
[](const NetAddr& addr) { return !addr.IsIPAddrV4(); }) !=
recordEntry.GetData()->mAddresses.end()) {
continue;
}
RefPtr<nsHostRecord> hostRecord;
rv = mHostResolver->GetHostRecord(
recordEntry.GetKey(), EmptyCString(),
nsIDNSService::RESOLVE_TYPE_DEFAULT, mRec->flags, AF_UNSPEC, mRec->pb,
mRec->originSuffix, getter_AddRefs(hostRecord));
if (NS_FAILED(rv)) {
LOG(("Failed to get host record for additional record %s",
nsCString(recordEntry.GetKey()).get()));
continue;
}
RefPtr<AddrInfo> ai(
new AddrInfo(recordEntry.GetKey(), ResolverType(), TRRTYPE_A,
std::move(recordEntry.GetData()->mAddresses),
recordEntry.GetData()->mTtl));
mHostResolver->MaybeRenewHostRecord(hostRecord);
// Since we're not actually calling NameLookup for this record, we need
// to set these fields to avoid assertions in CompleteLookup.
// This is quite hacky, and should be fixed.
hostRecord->Reset();
hostRecord->mResolving++;
hostRecord->mEffectiveTRRMode =
static_cast<nsIRequest::TRRMode>(mRec->mEffectiveTRRMode);
LOG(("Completing lookup for additional: %s",
nsCString(recordEntry.GetKey()).get()));
(void)mHostResolver->CompleteLookup(hostRecord, NS_OK, ai, mPB,
mOriginSuffix, TRRSkippedReason::TRR_OK,
this);
}
}
// Stores the IP hints carried by an SVCB/HTTPS record as an address record
// for the record's service domain name.
//
// When IPv6 is disabled by pref, non-IPv4 hints are dropped first. Nothing is
// stored when no hints remain, or when this request has no host record or
// resolver to store them through.
void TRR::StoreIPHintAsDNSRecord(const struct SVCB& aSVCBRecord) {
  LOG(("TRR::StoreIPHintAsDNSRecord [%p] [%s]", this,
       aSVCBRecord.mSvcDomainName.get()));

  // Both are dereferenced below. Bail out the same way
  // SaveAdditionalRecords() does when there is nothing to attach the hints
  // to, instead of crashing on a null pointer.
  if (!mRec || !mHostResolver) {
    return;
  }

  CopyableTArray<NetAddr> addresses;
  aSVCBRecord.GetIPHints(addresses);

  if (StaticPrefs::network_dns_disableIPv6()) {
    addresses.RemoveElementsBy(
        [](const NetAddr& addr) { return !addr.IsIPAddrV4(); });
  }

  if (addresses.IsEmpty()) {
    return;
  }

  RefPtr<nsHostRecord> hostRecord;
  nsresult rv = mHostResolver->GetHostRecord(
      aSVCBRecord.mSvcDomainName, EmptyCString(),
      nsIDNSService::RESOLVE_TYPE_DEFAULT,
      mRec->flags | nsIDNSService::RESOLVE_IP_HINT, AF_UNSPEC, mRec->pb,
      mRec->originSuffix, getter_AddRefs(hostRecord));
  if (NS_FAILED(rv)) {
    LOG(("Failed to get host record"));
    return;
  }

  mHostResolver->MaybeRenewHostRecord(hostRecord);

  RefPtr<AddrInfo> ai(new AddrInfo(aSVCBRecord.mSvcDomainName, ResolverType(),
                                   TRRTYPE_A, std::move(addresses), mTTL));

  // Since we're not actually calling NameLookup for this record, we need
  // to set these fields to avoid assertions in CompleteLookup.
  // This is quite hacky, and should be fixed.
  hostRecord->mResolving++;
  hostRecord->mEffectiveTRRMode =
      static_cast<nsIRequest::TRRMode>(mRec->mEffectiveTRRMode);
  (void)mHostResolver->CompleteLookup(hostRecord, NS_OK, ai, mPB, mOriginSuffix,
                                      TRRSkippedReason::TRR_OK, this);
}
// Hands the decoded answer to the host resolver.
//
// A/AAAA (and any other non by-type) answers are wrapped in an AddrInfo,
// annotated with fetch timings when |aChannel| is an nsITimedChannel, and
// passed to CompleteLookup(); afterwards the resolver and record references
// are dropped. TXT and HTTPSSVC answers go through CompleteLookupByType().
//
// Returns NS_ERROR_FAILURE when there is no resolver left to report to,
// NS_OK otherwise.
nsresult TRR::ReturnData(nsIChannel* aChannel) {
  // Both branches below dereference mHostResolver, so check it once up front
  // rather than only on the address path.
  if (!mHostResolver) {
    return NS_ERROR_FAILURE;
  }

  if (mType != TRRTYPE_TXT && mType != TRRTYPE_HTTPSSVC) {
    // create and populate an AddrInfo instance to pass on
    RefPtr<AddrInfo> ai(new AddrInfo(mHost, ResolverType(), mType,
                                     nsTArray<NetAddr>(), mDNS.mTtl));
    auto builder = ai->Build();
    builder.SetAddresses(std::move(mDNS.mAddresses));
    builder.SetCanonicalHostname(mCname);

    // Set timings.
    nsCOMPtr<nsITimedChannel> timedChan = do_QueryInterface(aChannel);
    if (timedChan) {
      TimeStamp asyncOpen, start, end;
      if (NS_SUCCEEDED(timedChan->GetAsyncOpen(&asyncOpen)) &&
          !asyncOpen.IsNull()) {
        builder.SetTrrFetchDuration(
            (TimeStamp::Now() - asyncOpen).ToMilliseconds());
      }
      if (NS_SUCCEEDED(timedChan->GetRequestStart(&start)) &&
          NS_SUCCEEDED(timedChan->GetResponseEnd(&end)) && !start.IsNull() &&
          !end.IsNull()) {
        builder.SetTrrFetchDurationNetworkOnly((end - start).ToMilliseconds());
      }
    }
    ai = builder.Finish();

    RecordReason(TRRSkippedReason::TRR_OK);
    (void)mHostResolver->CompleteLookup(mRec, NS_OK, ai, mPB, mOriginSuffix,
                                        mTRRSkippedReason, this);
    // The lookup is complete; dropping these makes a later FailData() a
    // no-op.
    mHostResolver = nullptr;
    mRec = nullptr;
  } else {
    RecordReason(TRRSkippedReason::TRR_OK);
    (void)mHostResolver->CompleteLookupByType(mRec, NS_OK, mResult,
                                              mTRRSkippedReason, mTTL, mPB);
  }
  return NS_OK;
}
// Completes the pending lookup with |error|.
//
// Returns NS_ERROR_FAILURE (and does nothing) when the resolver reference has
// already been dropped. Otherwise the lookup is completed through the by-type
// or the address path depending on mType, and the resolver and record
// references are released.
nsresult TRR::FailData(nsresult error) {
  if (!mHostResolver) {
    return NS_ERROR_FAILURE;
  }

  // If we didn't record a reason until now, record a default one.
  RecordReason(TRRSkippedReason::TRR_FAILED);

  const bool isByType = (mType == TRRTYPE_TXT) || (mType == TRRTYPE_HTTPSSVC);
  if (!isByType) {
    // create and populate an TRR AddrInfo instance to pass on to signal that
    // this comes from TRR
    RefPtr<AddrInfo> ai =
        new AddrInfo(mHost, ResolverType(), mType, nsTArray<NetAddr>());

    (void)mHostResolver->CompleteLookup(mRec, error, ai, mPB, mOriginSuffix,
                                        mTRRSkippedReason, this);
  } else {
    TypeRecordResultType empty(Nothing{});
    (void)mHostResolver->CompleteLookupByType(mRec, error, empty,
                                              mTRRSkippedReason, 0, mPB);
  }

  mRec = nullptr;
  mHostResolver = nullptr;
  return NS_OK;
}
// Translates a response decoding failure into a TRRSkippedReason.
//
// A non-zero RCODE in the packet wins (0x03 maps to TRR_NXDOMAIN, anything
// else to TRR_RCODE_FAIL); otherwise unknown-host statuses map to
// TRR_NO_ANSWERS and everything else to TRR_DECODE_FAILED.
// mPacket is dereferenced unconditionally, so it must already exist.
void TRR::HandleDecodeError(nsresult aStatusCode) {
  auto rcode = mPacket->GetRCode();

  if (rcode.isOk() && rcode.unwrap() != 0) {
    RecordReason(rcode.unwrap() == 0x03 ? TRRSkippedReason::TRR_NXDOMAIN
                                        : TRRSkippedReason::TRR_RCODE_FAIL);
    return;
  }

  const bool noAnswers = aStatusCode == NS_ERROR_UNKNOWN_HOST ||
                         aStatusCode == NS_ERROR_DEFINITIVE_UNKNOWN_HOST;
  if (noAnswers) {
    RecordReason(TRRSkippedReason::TRR_NO_ANSWERS);
    return;
  }

  RecordReason(TRRSkippedReason::TRR_DECODE_FAILED);
}
// Reports whether the decoded response holds data matching the requested
// type: at least one address for A/AAAA, or a result variant of the matching
// kind for TXT and HTTPSSVC. Any other type is never considered usable.
bool TRR::HasUsableResponse() {
  switch (mType) {
    case TRRTYPE_A:
    case TRRTYPE_AAAA:
      return !mDNS.mAddresses.IsEmpty();
    case TRRTYPE_TXT:
      return mResult.is<TypeRecordTxt>();
    case TRRTYPE_HTTPSSVC:
      return mResult.is<TypeRecordHTTPSSVC>();
    default:
      return false;
  }
}
// Resolves a CNAME chain.
//
// First the already-received packet is re-decoded for each CNAME target in
// turn, as long as no addresses were found and the loop budget (mCnameLoop)
// lasts. If that produces a usable answer it is returned via ReturnData().
// Otherwise, unless the server reported recursion as available together with
// an unknown-host result, a new TRR request for the last CNAME target is
// dispatched.
//
// Returns NS_ERROR_REDIRECT_LOOP when the loop budget is exhausted and
// NS_ERROR_FAILURE when the TRR service is gone.
nsresult TRR::FollowCname(nsIChannel* aChannel) {
  nsAutoCString target;
  nsresult rv = NS_OK;

  while (NS_SUCCEEDED(rv) && mDNS.mAddresses.IsEmpty() && !mCname.IsEmpty() &&
         mCnameLoop > 0) {
    mCnameLoop--;
    LOG(("TRR::On200Response CNAME %s => %s (%u)\n", mHost.get(), mCname.get(),
         mCnameLoop));

    // Decode() writes the next CNAME (if any) into mCname, so move the
    // current one aside first.
    target = mCname;
    mCname.Truncate();

    LOG(("TRR: check for CNAME record for %s within previous response\n",
         target.get()));
    nsClassHashtable<nsCStringHashKey, DOHresp> additionalRecords;
    rv = GetOrCreateDNSPacket()->Decode(
        target, mType, mCname, StaticPrefs::network_trr_allow_rfc1918(), mDNS,
        mResult, additionalRecords, mTTL);
    if (NS_FAILED(rv)) {
      LOG(("TRR::FollowCname DohDecode %x\n", (int)rv));
      HandleDecodeError(rv);
    }
  }

  // restore mCname as DohDecode() change it
  mCname = target;

  if (NS_SUCCEEDED(rv) && HasUsableResponse()) {
    ReturnData(aChannel);
    return NS_OK;
  }

  const bool recursionAvailable =
      mPacket && mPacket->RecursionAvailable().unwrapOr(false);
  LOG(("ra = %d", recursionAvailable));
  if (rv == NS_ERROR_UNKNOWN_HOST && recursionAvailable) {
    // If recursion is available, but no addresses have been returned,
    // we can just return a failure here.
    LOG(("TRR::FollowCname not sending another request as RA flag is set."));
    FailData(NS_ERROR_UNKNOWN_HOST);
    return NS_OK;
  }

  if (!mCnameLoop) {
    LOG(("TRR::On200Response CNAME loop, eject!\n"));
    return NS_ERROR_REDIRECT_LOOP;
  }

  LOG(("TRR::On200Response CNAME %s => %s (%u)\n", mHost.get(), mCname.get(),
       mCnameLoop));

  // Ask the server about the CNAME target in a follow-up request that
  // inherits our record, type, remaining loop budget and purpose.
  RefPtr<TRR> followUp =
      new TRR(mHostResolver, mRec, mCname, mType, mCnameLoop, mPB);
  followUp->SetPurpose(mPurpose);
  if (!TRRService::Get()) {
    return NS_ERROR_FAILURE;
  }
  return TRRService::Get()->DispatchTRRRequest(followUp);
}
// Handles an HTTP 200 response: decodes the DNS packet, optionally stores
// additional records and HTTPS/SVCB IP hints, and then either returns the
// answer or continues with the CNAME it pointed to.
//
// Returns the Decode() error on failure (after recording a skip reason), the
// FollowCname() result when a CNAME has to be chased, and NS_OK otherwise.
nsresult TRR::On200Response(nsIChannel* aChannel) {
  // For by-type records that carry an origin host, pass it to the packet
  // before decoding.
  RefPtr<TypeHostRecord> typeRec = do_QueryObject(mRec);
  if (typeRec && typeRec->mOriginHost) {
    GetOrCreateDNSPacket()->SetOriginHost(typeRec->mOriginHost);
  }

  // decode body and create an AddrInfo struct for the response
  nsClassHashtable<nsCStringHashKey, DOHresp> additionalRecords;
  const nsresult rv = GetOrCreateDNSPacket()->Decode(
      mHost, mType, mCname, StaticPrefs::network_trr_allow_rfc1918(), mDNS,
      mResult, additionalRecords, mTTL);
  if (NS_FAILED(rv)) {
    LOG(("TRR::On200Response DohDecode %x\n", (int)rv));
    HandleDecodeError(rv);
    return rv;
  }

  if (StaticPrefs::network_trr_add_additional_records()) {
    SaveAdditionalRecords(additionalRecords);
  }

  if (mResult.is<TypeRecordHTTPSSVC>()) {
    auto& svcbRecords = mResult.as<TypeRecordHTTPSSVC>();
    for (const auto& svcb : svcbRecords) {
      StoreIPHintAsDNSRecord(svcb);
    }
  }

  // A CNAME is only chased when there are no addresses, the query is not a
  // TXT query, and the response actually named a CNAME.
  const bool mustFollowCname = mDNS.mAddresses.IsEmpty() &&
                               mType != TRRTYPE_TXT && !mCname.IsEmpty();
  if (mustFollowCname) {
    LOG(("TRR::On200Response trying CNAME %s", mCname.get()));
    return FollowCname(aChannel);
  }

  // pass back the response data
  ReturnData(aChannel);
  return NS_OK;
}
// Records the time it took from the last received byte of the DoH response
// until we've notified the consumer with a host record. Nothing is recorded
// when the channel exposes no timing or has no response-end timestamp.
void TRR::RecordProcessingTime(nsIChannel* aChannel) {
  nsCOMPtr<nsITimedChannel> timedChan = do_QueryInterface(aChannel);
  if (!timedChan) {
    return;
  }

  TimeStamp responseEnd;
  if (NS_FAILED(timedChan->GetResponseEnd(&responseEnd)) ||
      responseEnd.IsNull()) {
    return;
  }

  Telemetry::AccumulateTimeDelta(Telemetry::DNS_TRR_PROCESSING_TIME,
                                 responseEnd);

  LOG(("Processing DoH response took %f ms",
       (TimeStamp::Now() - responseEnd).ToMilliseconds()));
}
// Reports the outcome of this request to the TRR service, but only for
// requests that went to the default server.
void TRR::ReportStatus(nsresult aStatusCode) {
  // If the TRR was cancelled by nsHostResolver, then we don't need to report
  // it as failed; otherwise it can cause the confirmation to fail.
  if (!UseDefaultServer() || aStatusCode == NS_ERROR_ABORT) {
    return;
  }

  // Bad content is still considered "okay" if the HTTP response is okay
  TRRService::Get()->RecordTRRStatus(this);
}
// Records which HTTP major version the provider answered with, keyed by the
// provider. Major versions other than 2 and 3 are counted under the h_1
// label. Nothing is recorded when the version cannot be obtained.
static void RecordHttpVersion(nsIHttpChannel* aHttpChannel) {
  nsCOMPtr<nsIHttpChannelInternal> internalChannel =
      do_QueryInterface(aHttpChannel);
  if (!internalChannel) {
    LOG(("RecordHttpVersion: Failed to QI nsIHttpChannelInternal"));
    return;
  }

  uint32_t verMajor, verMinor;
  if (NS_FAILED(internalChannel->GetResponseVersion(&verMajor, &verMinor))) {
    LOG(("RecordHttpVersion: Failed to get protocol version"));
    return;
  }

  auto label = Telemetry::LABELS_DNS_TRR_HTTP_VERSION2::h_1;
  switch (verMajor) {
    case 2:
      label = Telemetry::LABELS_DNS_TRR_HTTP_VERSION2::h_2;
      break;
    case 3:
      label = Telemetry::LABELS_DNS_TRR_HTTP_VERSION2::h_3;
      break;
    default:
      break;
  }

  Telemetry::AccumulateCategoricalKeyed(TRRService::ProviderKey(), label);

  LOG(("RecordHttpVersion: Provider responded using HTTP version: %d",
       verMajor));
}
// nsIRequestObserver: the HTTP transaction has finished.
//
// Releases the channel and the timeout timer, then - if the transfer
// succeeded - validates the content type and HTTP status and hands a 200
// response to On200Response(). Every path that did not produce an answer ends
// in FailData(). ReportStatus() runs on every exit through the scope guard.
NS_IMETHODIMP
TRR::OnStopRequest(nsIRequest* aRequest, nsresult aStatusCode) {
  // The dtor will be run after the function returns
  LOG(("TRR:OnStopRequest %p %s %d failed=%d code=%X\n", this, mHost.get(),
       mType, mFailed, (unsigned int)aStatusCode));
  nsCOMPtr<nsIChannel> channel;
  channel.swap(mChannel);

  mChannelStatus = aStatusCode;
  if (NS_SUCCEEDED(aStatusCode)) {
    nsCString label = "regular"_ns;
    if (mPB) {
      label = "private"_ns;
    }
    mozilla::glean::networking::trr_request_count.Get(label).Add(1);
  }

  {
    // Cancel the timer since we don't need it anymore.
    nsCOMPtr<nsITimer> timer;
    mTimeout.swap(timer);
    if (timer) {
      timer->Cancel();
    }
  }

  auto scopeExit = MakeScopeExit([&] { ReportStatus(aStatusCode); });

  nsresult rv = NS_OK;
  // if status was "fine", parse the response and pass on the answer
  if (!mFailed && NS_SUCCEEDED(aStatusCode)) {
    nsCOMPtr<nsIHttpChannel> httpChannel = do_QueryInterface(aRequest);
    if (!httpChannel) {
      return NS_ERROR_UNEXPECTED;
    }
    nsAutoCString contentType;
    httpChannel->GetContentType(contentType);
    if (contentType.Length() &&
        !contentType.LowerCaseEqualsASCII(ContentType())) {
      LOG(("TRR:OnStopRequest %p %s %d wrong content type %s\n", this,
           mHost.get(), mType, contentType.get()));
      FailData(NS_ERROR_UNEXPECTED);
      return NS_OK;
    }

    // Initialised so that the log statement below never reads an
    // indeterminate value when GetResponseStatus() fails.
    uint32_t httpStatus = 0;
    rv = httpChannel->GetResponseStatus(&httpStatus);
    if (NS_SUCCEEDED(rv) && httpStatus == 200) {
      rv = On200Response(channel);
      if (NS_SUCCEEDED(rv)) {
        // Telemetry is only collected for the default server.
        if (UseDefaultServer()) {
          RecordReason(TRRSkippedReason::TRR_OK);
          RecordProcessingTime(channel);
          RecordHttpVersion(httpChannel);
        }
        // The response was handled (answer returned or a CNAME follow-up
        // dispatched). Never fall through to FailData() from here: doing so
        // would complete the same lookup a second time with an error.
        return rv;
      }
    } else {
      RecordReason(TRRSkippedReason::TRR_SERVER_RESPONSE_ERR);
      LOG(("TRR:OnStopRequest:%d %p rv %x httpStatus %d\n", __LINE__, this,
           (int)rv, httpStatus));
    }
  }

  LOG(("TRR:OnStopRequest %p status %x mFailed %d\n", this, (int)aStatusCode,
       mFailed));
  FailData(NS_SUCCEEDED(rv) ? NS_ERROR_UNKNOWN_HOST : rv);
  return NS_OK;
}
// nsIStreamListener: feeds incoming response bytes to the DNS packet. Once a
// read has failed, mFailed is set and every later call is rejected with
// NS_ERROR_FAILURE.
NS_IMETHODIMP
TRR::OnDataAvailable(nsIRequest* aRequest, nsIInputStream* aInputStream,
                     uint64_t aOffset, const uint32_t aCount) {
  LOG(("TRR:OnDataAvailable %p %s %d failed=%d aCount=%u\n", this, mHost.get(),
       mType, mFailed, (unsigned int)aCount));

  if (mFailed) {
    return NS_ERROR_FAILURE;
  }

  // receive DNS response into the local buffer
  const nsresult rv = GetOrCreateDNSPacket()->OnDataAvailable(
      aRequest, aInputStream, aOffset, aCount);
  if (NS_SUCCEEDED(rv)) {
    return NS_OK;
  }

  LOG(("TRR::OnDataAvailable:%d fail\n", __LINE__));
  mFailed = true;
  return rv;
}
// Cancels the in-flight request with |aStatus|.
//
// The call re-dispatches itself when invoked on the wrong thread: to the TRR
// thread for a TRR service channel outside the socket process that is not an
// nsHttpChannel, and to the main thread in every other case. Once on the
// right thread, the first call marks the request cancelled and cancels the
// channel if there is one; later calls are no-ops.
void TRR::Cancel(nsresult aStatus) {
  // Builds a runnable that re-enters Cancel() with the same status while
  // keeping this object alive.
  auto makeRetry = [this, aStatus]() {
    return NS_NewRunnableFunction(
        "TRR::Cancel",
        [self = RefPtr(this), aStatus]() { self->Cancel(aStatus); });
  };

  bool isTRRServiceChannel = false;
  if (nsCOMPtr<nsIHttpChannelInternal> internal =
          do_QueryInterface(mChannel)) {
    if (NS_FAILED(internal->GetIsTRRServiceChannel(&isTRRServiceChannel))) {
      isTRRServiceChannel = false;
    }
  }

  // nsHttpChannel can be only canceled on the main thread.
  RefPtr<nsHttpChannel> httpChannel = do_QueryObject(mChannel);
  const bool cancelOnTRRThread =
      isTRRServiceChannel && !XRE_IsSocketProcess() && !httpChannel;

  if (cancelOnTRRThread) {
    if (TRRService::Get()) {
      nsCOMPtr<nsIThread> thread = TRRService::Get()->TRRThread();
      if (thread && !thread->IsOnCurrentThread()) {
        thread->Dispatch(makeRetry());
        return;
      }
    }
  } else if (!NS_IsMainThread()) {
    NS_DispatchToMainThread(makeRetry());
    return;
  }

  if (mCancelled) {
    return;
  }
  mCancelled = true;

  if (!mChannel) {
    return;
  }

  RecordReason(TRRSkippedReason::TRR_REQ_CANCELLED);
  LOG(("TRR: %p canceling Channel %p %s %d status=%" PRIx32 "\n", this,
       mChannel.get(), mHost.get(), mType, static_cast<uint32_t>(aStatus)));
  mChannel->Cancel(aStatus);
}
// True when the request goes to the service-wide TRR server: either no host
// record is attached, or the record does not name a server of its own.
bool TRR::UseDefaultServer() {
  if (!mRec) {
    return true;
  }
  return mRec->mTrrServer.IsEmpty();
}
} // namespace net
} // namespace mozilla