Source code
Revision control
Copy as Markdown
Other Tools
/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "TestCommon.h"
#include "gtest/gtest.h"
#include "Http3ConnectUDPStream.h"
#include "Http3Session.h"
#include "nsIUDPSocket.h"
#include "nsIIOService.h"
#include "nsIProtocolProxyService.h"
#include "nsIProtocolHandler.h"
#include "nsThreadUtils.h"
#include "nsStringStream.h"
#include "nsProxyInfo.h"
#include "nsHttpConnectionInfo.h"
#include "nsHttpRequestHead.h"
#include "nsHttpHandler.h"
#include "mozilla/Components.h"
using namespace mozilla;
using namespace mozilla::net;
// Shared test parameters. Declared constexpr so neither the pointers nor the
// values can be reassigned by accident (a plain `static const char*` only
// makes the pointee const, not the pointer).

// Host name of the MASQUE proxy; the tests expect it as the :authority value.
static constexpr const char* kProxyHost = "proxy.org";
// Target host and port that the CONNECT-UDP tunnel is opened to.
static constexpr const char* kHost = "example.com";
static constexpr int32_t kPort = 4433;
// URI template handed to the proxy info when building the connection info.
static constexpr const char* kPathTemplate =
    "/.well-known/masque/udp/{target_host}/{target_port}/";
// The :path the tests expect: kPathTemplate expanded with kHost and kPort.
static constexpr const char* kPathHeader =
    "/.well-known/masque/udp/example.com/4433/";
// In-memory stand-in for an HTTP/3 session, used to drive
// Http3ConnectUDPStream without a real connection.
// - TryActivating() records the :authority and :path values it receives.
// - SendHTTPDatagram() appends outgoing datagram payloads to a buffer that
//   tests drain with TakeOutputData().
// - StreamHasDataToWrite() queues a stream; ProcessOutput() flushes the queue.
// - FinishTunnelSetup() only sets a flag the tests can query.
// All other session operations are no-ops that report success.
class Http3SessionStub final : public Http3SessionBase {
 public:
  NS_INLINE_DECL_REFCOUNTING(Http3SessionStub, override)

  // Records the request pseudo-headers so tests can check them.
  // NOTE(review): *aStreamId is intentionally left untouched, as before.
  nsresult TryActivating(const nsACString& aMethod, const nsACString& aScheme,
                         const nsACString& aAuthorityHeader,
                         const nsACString& aPath, const nsACString& aHeaders,
                         uint64_t* aStreamId,
                         Http3StreamBase* aStream) override {
    mPathHeader = aPath;
    mAuthHeader = aAuthorityHeader;
    return NS_OK;
  }

  void CloseSendingSide(uint64_t aStreamId) override {}

  // Captures the datagram payload instead of sending it.
  void SendHTTPDatagram(uint64_t aStreamId, nsTArray<uint8_t>& aData,
                        uint64_t aTrackingId) override {
    mOutputData.AppendElements(aData);
  }

  nsresult SendRequestBody(uint64_t aStreamId, const char* buf, uint32_t count,
                           uint32_t* countRead) override {
    return NS_OK;
  }

  // Never produces response data and never signals FIN.
  nsresult ReadResponseData(uint64_t aStreamId, char* aBuf, uint32_t aCount,
                            uint32_t* aCountWritten, bool* aFin) override {
    *aCountWritten = 0;
    *aFin = false;
    return NS_OK;
  }

  nsresult SendPriorityUpdateFrame(uint64_t aStreamId, uint8_t aPriorityUrgency,
                                   bool aPriorityIncremental) override {
    return NS_OK;
  }

  void ConnectSlowConsumer(Http3StreamBase* stream) override {}

  void CloseWebTransportConn() override {}

  // Queues |aStream| until the next ProcessOutput() call.
  void StreamHasDataToWrite(Http3StreamBase* aStream) override {
    mReadyForWrite.AppendElement(aStream);
  }

  nsresult CloseWebTransport(uint64_t aSessionId, uint32_t aError,
                             const nsACString& aMessage) override {
    return NS_OK;
  }

  void SendDatagram(Http3WebTransportSession* aSession,
                    nsTArray<uint8_t>& aData, uint64_t aTrackingId) override {}

  uint64_t MaxDatagramSize(uint64_t aSessionId) override { return 0; }

  nsresult TryActivatingWebTransportStream(uint64_t* aStreamId,
                                           Http3StreamBase* aStream) override {
    *aStreamId = 0;
    return NS_OK;
  }

  void ResetWebTransportStream(Http3WebTransportStream* aStream,
                               uint64_t aErrorCode) override {}

  void StreamStopSending(Http3WebTransportStream* aStream,
                         uint8_t aErrorCode) override {}

  void SetSendOrder(Http3StreamBase* aStream,
                    Maybe<int64_t> aSendOrder) override {}

  // Calls ReadSegments() on every queued stream.
  // The queue is moved into a local first. A stream may call
  // StreamHasDataToWrite() again from inside ReadSegments(); without the move,
  // that would append to the array being range-iterated, which could
  // reallocate it and invalidate the loop's iterators. Streams re-queued this
  // way stay queued for the next ProcessOutput() call.
  void ProcessOutput() {
    nsTArray<RefPtr<Http3StreamBase>> pending = std::move(mReadyForWrite);
    for (const auto& stream : pending) {
      Unused << stream->ReadSegments();
    }
  }

  void FinishTunnelSetup(nsAHttpTransaction* aTransaction) override {
    mFinishTunnelSetupCalled = true;
  }

  // True once FinishTunnelSetup() has been called.
  bool FinishTunnelSetupCalled() const { return mFinishTunnelSetupCalled; }

  // Returns all captured datagram bytes and empties the internal buffer.
  nsTArray<uint8_t> TakeOutputData() { return std::move(mOutputData); }

  // Last :path / :authority values passed to TryActivating().
  const nsCString& PathHeader() const { return mPathHeader; }
  const nsCString& AuthHeader() const { return mAuthHeader; }

 private:
  ~Http3SessionStub() = default;

  nsTArray<RefPtr<Http3StreamBase>> mReadyForWrite;
  nsTArray<uint8_t> mOutputData;
  nsCString mPathHeader;
  nsCString mAuthHeader;
  bool mFinishTunnelSetupCalled = false;
};
// Minimal nsAHttpTransaction used as the request side of a CONNECT-UDP stream.
// - The request data it hands out is "capsule-protocol = ?1\r\n\r\n".
// - Its connection info targets kHost:kPort through a MASQUE proxy at
//   kProxyHost using kPathTemplate.
// - Response bytes written to it are read into a scratch buffer and discarded;
//   the first WriteSegments() call marks it done.
class DummyHttpTransaction : public nsAHttpTransaction {
 public:
  NS_DECL_THREADSAFE_ISUPPORTS

  DummyHttpTransaction() {
    nsCString buffer;
    buffer.AssignLiteral("capsule-protocol = ?1\r\n\r\n");
    NS_NewCStringInputStream(getter_AddRefs(mRequestStream), buffer);

    nsCOMPtr<nsIProtocolProxyService> pps;
    pps = mozilla::components::ProtocolProxy::Service();
    if (pps) {
      nsCOMPtr<nsIProxyInfo> info;
      nsresult rv = pps->NewMASQUEProxyInfo(
          nsCString(kProxyHost), -1, nsCString(kPathTemplate), ""_ns, ""_ns, 0,
          0, nullptr, getter_AddRefs(info));
      if (NS_FAILED(rv)) {
        // Leave mConnInfo null; ConnectionInfo() will then return nullptr.
        return;
      }
      mConnInfo = new nsHttpConnectionInfo(
          nsCString(kHost), kPort, ""_ns, ""_ns,
          static_cast<nsProxyInfo*>(info.get()), OriginAttributes());
    }
  }

  // nsIInputStream::ReadSegments callback: forwards each chunk of request data
  // to the reader captured by ReadSegments(). |closure| is the transaction.
  static nsresult ReadRequestSegment(nsIInputStream* stream, void* closure,
                                     const char* buf, uint32_t offset,
                                     uint32_t count, uint32_t* countRead) {
    auto* trans = static_cast<DummyHttpTransaction*>(closure);
    return trans->mReader->OnReadSegment(buf, count, countRead);
  }

  void SetConnection(nsAHttpConnection*) override {}
  nsAHttpConnection* Connection() override { return nullptr; }
  void GetSecurityCallbacks(nsIInterfaceRequestor**) override {}
  void OnTransportStatus(nsITransport* transport, nsresult status,
                         int64_t progress) override {}
  bool IsDone() override { return mIsDone; }
  nsresult Status() override { return NS_OK; }
  uint32_t Caps() override { return 0; }

  // Feeds up to |count| bytes of the request data to |reader|.
  // |reader| is only borrowed for the duration of this call.
  [[nodiscard]] nsresult ReadSegments(nsAHttpSegmentReader* reader,
                                      uint32_t count,
                                      uint32_t* countRead) override {
    mReader = reader;
    // Errors from the underlying stream are ignored; the stub always reports
    // success.
    Unused << mRequestStream->ReadSegments(ReadRequestSegment, this, count,
                                           countRead);
    mReader = nullptr;
    return NS_OK;
  }

  // Pulls (and discards) response bytes from |writer|, then marks the
  // transaction done.
  [[nodiscard]] nsresult WriteSegments(nsAHttpSegmentWriter* writer,
                                       uint32_t count,
                                       uint32_t* countWritten) override {
    char buf[1024];
    // Never request more than the caller allowed, nor more than fits in |buf|.
    const uint32_t toRead =
        count < sizeof(buf) ? count : static_cast<uint32_t>(sizeof(buf));
    Unused << writer->OnWriteSegment(buf, toRead, countWritten);
    mIsDone = true;
    return NS_OK;
  }

  void Close(nsresult reason) override {}
  nsHttpConnectionInfo* ConnectionInfo() override { return mConnInfo.get(); }
  void SetProxyConnectFailed() override {}

  // Lazily builds a request head whose only header is Host: kHost.
  nsHttpRequestHead* RequestHead() override {
    if (mRequestHead) {
      return mRequestHead.get();
    }
    mRequestHead = MakeUnique<nsHttpRequestHead>();
    Unused << mRequestHead->SetHeader(nsHttp::Host, nsCString(kHost));
    return mRequestHead.get();
  }

  uint32_t Http1xTransactionCount() override { return 0; }

  [[nodiscard]] nsresult TakeSubTransactions(
      nsTArray<RefPtr<nsAHttpTransaction>>& outTransactions) override {
    return NS_OK;
  }

 private:
  virtual ~DummyHttpTransaction() = default;

  // Non-owning; only valid while ReadSegments() is running.
  nsAHttpSegmentReader* mReader{nullptr};
  nsCOMPtr<nsIInputStream> mRequestStream;
  UniquePtr<nsHttpRequestHead> mRequestHead;
  bool mIsDone = false;
  RefPtr<nsHttpConnectionInfo> mConnInfo;
};
// XPCOM AddRef/Release/QueryInterface for DummyHttpTransaction, exposing
// nsISupportsWeakReference (presumably inherited via nsAHttpTransaction — confirm).
NS_IMPL_ISUPPORTS(DummyHttpTransaction, nsISupportsWeakReference)
// Synchronous UDP listener used by the tests. Every received payload is
// appended to one buffer, which TakeInputData() hands out and empties.
// It also remembers whether OnStopListening() has fired.
class UDPListener final : public nsIUDPSocketSyncListener {
 public:
  NS_DECL_ISUPPORTS

  UDPListener() = default;

  // Pulls one packet from |aSocket| and appends its bytes to the buffer.
  // The result of RecvWithAddr() is not checked.
  NS_IMETHOD OnPacketReceived(nsIUDPSocket* aSocket) override {
    NetAddr fromAddr{};
    nsTArray<uint8_t> packet;
    Unused << aSocket->RecvWithAddr(&fromAddr, packet);
    mReceivedData.AppendElements(packet);
    return NS_OK;
  }

  // Records that the socket stopped listening; |aStatus| is ignored.
  NS_IMETHOD OnStopListening(nsIUDPSocket* aSocket, nsresult aStatus) override {
    mOnStopListeningCalled = true;
    return NS_OK;
  }

  // Moves out everything received so far.
  nsTArray<uint8_t> TakeInputData() { return std::move(mReceivedData); }

  bool OnStopListeningCalled() const { return mOnStopListeningCalled; }

 private:
  ~UDPListener() = default;

  nsTArray<uint8_t> mReceivedData;
  bool mOnStopListeningCalled = false;
};
// XPCOM AddRef/Release/QueryInterface for UDPListener, exposing
// nsIUDPSocketSyncListener.
NS_IMPL_ISUPPORTS(UDPListener, nsIUDPSocketSyncListener)
static void InitHttpHandler() {
if (gHttpHandler) {
return;
}
nsresult rv;
nsCOMPtr<nsIIOService> ios = do_GetIOService(&rv);
if (NS_FAILED(rv)) {
return;
}
nsCOMPtr<nsIProtocolHandler> handler;
rv = ios->GetProtocolHandler("http", getter_AddRefs(handler));
if (NS_FAILED(rv)) {
return;
}
}
// Creates an Http3ConnectUDPStream on |aSession| and drives it through tunnel
// setup:
//  1. builds the stream with a fresh DummyHttpTransaction and a 127.0.0.1 peer;
//  2. queues and flushes it through the session stub (the stub records the
//     :authority / :path headers it is activated with);
//  3. feeds it the raw response "HTTP/3 200\n\n" and calls WriteSegments().
// The callers expect FinishTunnelSetup() to have been called on the session
// once this returns.
static already_AddRefed<Http3ConnectUDPStream> CreateUDPStream(
    Http3SessionStub* aSession) {
  RefPtr<DummyHttpTransaction> trans = new DummyHttpTransaction();
  RefPtr<Http3ConnectUDPStream> stream =
      new Http3ConnectUDPStream(trans, aSession, NS_GetCurrentThread());

  NetAddr peerAddr;
  peerAddr.InitFromString("127.0.0.1"_ns);
  stream->SetPeerAddr(peerAddr);

  aSession->StreamHasDataToWrite(stream);
  aSession->ProcessOutput();

  // "HTTP/3 200\n\n". This is a byte array, not a string literal, so there is
  // no NUL terminator to exclude: every byte belongs to the response. The
  // previous `sizeof(kResponse) - 1` silently dropped the final '\n'.
  static constexpr uint8_t kResponse[] = {0x48, 0x54, 0x54, 0x50, 0x2F, 0x33,
                                          0x20, 0x32, 0x30, 0x30, 0x0A, 0x0A};
  static constexpr uint32_t kResponseLen = sizeof(kResponse);
  nsTArray<uint8_t> response;
  response.AppendElements(kResponse, kResponseLen);
  stream->SetResponseHeaders(response, false, false);
  Unused << stream->WriteSegments();
  return stream.forget();
}
namespace ConnectUdp::testing {
static void CreateTestData(uint32_t aNumBytes, nsTArray<uint8_t>& aDataOut) {
static constexpr const char kSampleText[] =
"{\"type\":\"message\",\"id\":42,\"payload\":\"The quick brown fox jumps "
"over the lazy dog.\"}";
static constexpr uint32_t kSampleTextLen = sizeof(kSampleText) - 1;
aDataOut.SetCapacity(aNumBytes);
while (aNumBytes > 0) {
uint32_t chunkSize = std::min(kSampleTextLen, aNumBytes);
aDataOut.AppendElements(reinterpret_cast<const uint8_t*>(kSampleText),
chunkSize);
aNumBytes -= chunkSize;
}
}
static void ValidateData(nsTArray<uint8_t>& aInput,
nsTArray<uint8_t>& aExpectedData) {
ASSERT_EQ(aExpectedData.Length(), aInput.Length());
for (size_t i = 0; i < aExpectedData.Length(); i++) {
ASSERT_EQ(aExpectedData[i], aInput[i]);
}
}
} // namespace ConnectUdp::testing
// Sending on a CONNECT-UDP stream that was never activated on the session must
// fail with NS_ERROR_NOT_AVAILABLE.
TEST(ConnectUDP, SendDataBeforeActivate)
{
  InitHttpHandler();

  // No transaction, and the stream is never queued on the session.
  RefPtr<Http3SessionStub> stubSession = new Http3SessionStub();
  RefPtr<Http3ConnectUDPStream> udpStream =
      new Http3ConnectUDPStream(nullptr, stubSession, NS_GetCurrentThread());
  nsCOMPtr<nsIUDPSocket> socket = static_cast<nsIUDPSocket*>(udpStream.get());
  ASSERT_TRUE(socket);

  NetAddr target;
  target.InitFromString("127.0.0.1"_ns);

  nsTArray<uint8_t> payload;
  ConnectUdp::testing::CreateTestData(100, payload);

  uint32_t bytesWritten = 0;
  ASSERT_EQ(socket->SendWithAddress(&target, payload.Elements(),
                                    payload.Length(), &bytesWritten),
            NS_ERROR_NOT_AVAILABLE);
}
// Data written to an activated CONNECT-UDP socket must reach the session
// unchanged as HTTP datagrams, and ByteCountSent() must total all sends.
TEST(ConnectUDP, SendData)
{
  InitHttpHandler();
  RefPtr<Http3SessionStub> session = new Http3SessionStub();
  RefPtr<Http3ConnectUDPStream> stream = CreateUDPStream(session);
  ASSERT_TRUE(session->FinishTunnelSetupCalled());
  ASSERT_TRUE(session->AuthHeader().EqualsASCII(kProxyHost));
  ASSERT_TRUE(session->PathHeader().EqualsASCII(kPathHeader));

  nsCOMPtr<nsIUDPSocket> udp = static_cast<nsIUDPSocket*>(stream.get());
  ASSERT_TRUE(udp);

  NetAddr peerAddr;
  peerAddr.InitFromString("127.0.0.1"_ns);

  // First send: 100 bytes.
  nsTArray<uint8_t> data;
  ConnectUdp::testing::CreateTestData(100, data);
  uint32_t written = 0;
  nsresult rv =
      udp->SendWithAddress(&peerAddr, data.Elements(), data.Length(), &written);
  ASSERT_EQ(rv, NS_OK);
  NS_ProcessPendingEvents(nullptr);
  session->ProcessOutput();
  nsTArray<uint8_t> output = session->TakeOutputData();
  // ValidateData(actual, expected). ASSERT_NO_FATAL_FAILURE makes a mismatch
  // inside the helper abort this test.
  ASSERT_NO_FATAL_FAILURE(ConnectUdp::testing::ValidateData(output, data));

  // Second send: 200 bytes.
  data.Clear();
  ConnectUdp::testing::CreateTestData(200, data);
  rv =
      udp->SendWithAddress(&peerAddr, data.Elements(), data.Length(), &written);
  ASSERT_EQ(rv, NS_OK);
  NS_ProcessPendingEvents(nullptr);
  session->ProcessOutput();
  output = session->TakeOutputData();
  ASSERT_NO_FATAL_FAILURE(ConnectUdp::testing::ValidateData(output, data));

  ASSERT_EQ(stream->ByteCountSent(), 300u);
  udp->Close();
}
// Datagrams delivered to an activated CONNECT-UDP stream must reach a
// synchronous listener intact. ByteCountReceived() must total them, and
// Close() must notify the listener.
TEST(ConnectUDP, RecvData)
{
  InitHttpHandler();
  RefPtr<Http3SessionStub> session = new Http3SessionStub();
  RefPtr<Http3ConnectUDPStream> stream = CreateUDPStream(session);
  ASSERT_TRUE(session->FinishTunnelSetupCalled());
  ASSERT_TRUE(session->AuthHeader().EqualsASCII(kProxyHost));
  ASSERT_TRUE(session->PathHeader().EqualsASCII(kPathHeader));

  nsCOMPtr<nsIUDPSocket> udp = static_cast<nsIUDPSocket*>(stream.get());
  ASSERT_TRUE(udp);

  RefPtr<UDPListener> listener = new UDPListener();
  nsresult rv = udp->SyncListen(listener);
  ASSERT_EQ(rv, NS_OK);

  // First datagram: 100 bytes. The buffer is moved into the stream, so the
  // expected bytes are regenerated separately (CreateTestData is
  // deterministic).
  nsTArray<uint8_t> expected1;
  ConnectUdp::testing::CreateTestData(100, expected1);
  nsTArray<uint8_t> datagram1;
  ConnectUdp::testing::CreateTestData(100, datagram1);
  stream->OnDatagramReceived(std::move(datagram1));
  nsTArray<uint8_t> input = listener->TakeInputData();
  ASSERT_EQ(input.Length(), 100u);
  ASSERT_NO_FATAL_FAILURE(ConnectUdp::testing::ValidateData(input, expected1));

  // Second datagram: 200 bytes. Use a fresh array instead of the moved-from
  // one.
  nsTArray<uint8_t> expected2;
  ConnectUdp::testing::CreateTestData(200, expected2);
  nsTArray<uint8_t> datagram2;
  ConnectUdp::testing::CreateTestData(200, datagram2);
  stream->OnDatagramReceived(std::move(datagram2));
  input = listener->TakeInputData();
  ASSERT_EQ(input.Length(), 200u);
  ASSERT_NO_FATAL_FAILURE(ConnectUdp::testing::ValidateData(input, expected2));

  ASSERT_EQ(stream->ByteCountReceived(), 300u);
  udp->Close();
  ASSERT_EQ(listener->OnStopListeningCalled(), true);
}