Source code

Revision control

Copy as Markdown

Other Tools

/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#ifndef tls_connect_h_
#define tls_connect_h_
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>
#include "sslproto.h"
#include "sslt.h"
#include "nss.h"
#include "tls_agent.h"
#include "tls_filter.h"
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
namespace nss_test {
// Declared here, defined in another translation unit.  NOTE(review): going by
// its name this presumably renders a protocol version number as text for test
// output — confirm against the definition.
std::string VersionString(uint16_t version);
// A generic TLS connection test base.
class TlsConnectTestBase : public ::testing::Test {
public:
static ::testing::internal::ParamGenerator<SSLProtocolVariant>
kTlsVariantsStream;
static ::testing::internal::ParamGenerator<SSLProtocolVariant>
kTlsVariantsDatagram;
static ::testing::internal::ParamGenerator<SSLProtocolVariant>
kTlsVariantsAll;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV11;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV12;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV13;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll;
TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
virtual ~TlsConnectTestBase();
virtual void SetUp();
virtual void TearDown();
PRTime now() const { return now_; }
// Initialize client and server.
void Init();
// Clear the statistics.
void ClearStats();
// Clear the server session cache.
void ClearServerCache();
// Make sure TLS is configured for a connection.
virtual void EnsureTlsSetup();
// Reset and keep the same certificate names
void Reset();
// Reset, and update the certificate names on both peers
void Reset(const std::string& server_name,
const std::string& client_name = "client");
// Replace the server.
void MakeNewServer();
// Set up
void StartConnect();
// Run the handshake.
void Handshake();
// Connect and check that it works.
void Connect();
// Check that the connection was successfully established.
void CheckConnected();
// Connect and expect it to fail.
void ConnectExpectFail();
void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
void ConnectWithCipherSuite(uint16_t cipher_suite);
void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
size_t expected_size);
// Check that the keys used in the handshake match expectations.
void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const;
// This version guesses some of the values.
void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const;
// This version assumes defaults.
void CheckKeys() const;
// Check that keys on resumed sessions.
void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group,
SSLNamedGroup original_kea_group,
SSLAuthType auth_type,
SSLSignatureScheme sig_scheme);
void CheckGroups(const DataBuffer& groups,
std::function<void(SSLNamedGroup)> check_group);
void CheckShares(const DataBuffer& shares,
std::function<void(SSLNamedGroup)> check_group);
void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const;
void ConfigureVersion(uint16_t version);
void SetExpectedVersion(uint16_t version);
// Expect resumption of a particular type.
void ExpectResumption(SessionResumptionMode expected,
uint8_t num_resumed = 1);
void DisableAllCiphers();
void EnableOnlyStaticRsaCiphers();
void EnableOnlyDheCiphers();
void EnableSomeEcdhCiphers();
void EnableExtendedMasterSecret();
void ConfigureSelfEncrypt();
void ConfigureSessionCache(SessionResumptionMode client,
SessionResumptionMode server);
void EnableAlpn();
void EnableAlpnWithCallback(const std::vector<uint8_t>& client,
std::string server_choice);
void EnableAlpn(const std::vector<uint8_t>& vals);
void EnsureModelSockets();
void CheckAlpn(const std::string& val);
void EnableSrtp();
void CheckSrtp() const;
void SendReceive(size_t total = 50);
void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash,
uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL);
void RemovePsk(std::string label);
void SetupForZeroRtt();
void SetupForResume();
void ZeroRttSendReceive(
bool expect_writable, bool expect_readable,
std::function<bool()> post_clienthello_check = nullptr);
void Receive(size_t amount);
void ExpectExtendedMasterSecret(bool expected);
void ExpectEarlyDataAccepted(bool expected);
void EnableECDHEServerKeyReuse();
void SkipVersionChecks();
// Move the DTLS timers for both endpoints to pop the next timer.
void ShiftDtlsTimers();
void AdvanceTime(PRTime time_shift);
void ResetAntiReplay(PRTime window);
void RolloverAntiReplay();
void SaveAlgorithmPolicy();
void RestoreAlgorithmPolicy();
static ScopedSECItem MakeEcKeyParams(SSLNamedGroup group);
static void GenerateEchConfig(
HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites,
const std::string& public_name, uint16_t max_name_len, DataBuffer& record,
ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey);
void SetupEch(std::shared_ptr<TlsAgent>& client,
std::shared_ptr<TlsAgent>& server,
HpkeKemId kem_id = HpkeDhKemX25519Sha256,
bool expect_ech = true, bool set_client_config = true,
bool set_server_config = true, int maxConfigSize = 100);
protected:
SSLProtocolVariant variant_;
std::shared_ptr<TlsAgent> client_;
std::shared_ptr<TlsAgent> server_;
std::unique_ptr<TlsAgent> client_model_;
std::unique_ptr<TlsAgent> server_model_;
uint16_t version_;
SessionResumptionMode expected_resumption_mode_;
uint8_t expected_resumptions_;
std::vector<std::vector<uint8_t>> session_ids_;
ScopedSSLAntiReplayContext anti_replay_;
// A simple value of "a", "b". Note that the preferred value of "a" is placed
// at the end, because the NSS API follows the now defunct NPN specification,
// which places the preferred (and default) entry at the end of the list.
// NSS will move this final entry to the front when used with ALPN.
const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61};
// A list of algorithm IDs whose policies need to be preserved
// around test cases. In particular, DSA is checked in
// ssl_extension_unittest.cc.
const std::vector<SECOidTag> algorithms_ = {SEC_OID_APPLY_SSL_POLICY,
SEC_OID_ANSIX9_DSA_SIGNATURE,
SEC_OID_CURVE25519, SEC_OID_SHA1};
std::vector<std::tuple<SECOidTag, uint32_t>> saved_policies_;
const std::vector<PRInt32> options_ = {
NSS_RSA_MIN_KEY_SIZE, NSS_DH_MIN_KEY_SIZE, NSS_DSA_MIN_KEY_SIZE,
NSS_TLS_VERSION_MIN_POLICY, NSS_TLS_VERSION_MAX_POLICY};
std::vector<std::tuple<PRInt32, uint32_t>> saved_options_;
private:
void CheckResumption(SessionResumptionMode expected);
void CheckExtendedMasterSecret();
void CheckEarlyDataAccepted();
static PRTime TimeFunc(void* arg);
bool expect_extended_master_secret_;
bool expect_early_data_accepted_;
bool skip_version_checks_;
PRTime now_;
// Track groups and make sure that there are no duplicates.
class DuplicateGroupChecker {
public:
void AddAndCheckGroup(SSLNamedGroup group) {
EXPECT_EQ(groups_.end(), groups_.find(group))
<< "Group " << group << " should not be duplicated";
groups_.insert(group);
}
private:
std::set<SSLNamedGroup> groups_;
};
};
// Non-parametrized fixture that always uses the stream (TLS) variant and
// hands version 0 to TlsConnectTestBase.
class TlsConnectTest : public TlsConnectTestBase {
 public:
  TlsConnectTest()
      : TlsConnectTestBase(ssl_variant_stream, /* version = */ 0) {}
};
// Non-parametrized fixture pinned to the datagram (DTLS) variant; version 0 is
// handed to TlsConnectTestBase.
class DtlsConnectTest : public TlsConnectTestBase {
 public:
  DtlsConnectTest()
      : TlsConnectTestBase(ssl_variant_datagram, /* version = */ 0) {}
};
// Stream-only (TLS) fixture; the gtest parameter supplies the protocol version.
class TlsConnectStream
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<uint16_t> {
 public:
  TlsConnectStream()
      : TlsConnectTestBase(ssl_variant_stream, GetParam()) {}
};
// Stream-only fixture for tests that target versions before TLS 1.3.  It adds
// nothing to TlsConnectStream; being a distinct type lets it get its own gtest
// instantiation (presumably with pre-1.3 versions, elsewhere).
class TlsConnectStreamPre13 : public TlsConnectStream {
  // Intentionally empty.
};
// Datagram-only (DTLS) fixture; the gtest parameter supplies the protocol
// version.
class TlsConnectDatagram
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<uint16_t> {
 public:
  TlsConnectDatagram()
      : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {}
};
// Generic fixture parametrized over (protocol variant, version): each test runs
// as either stream or datagram at a single TLS version.  The parameter values
// are configured in ssl_loopback_unittest.cc; the constructor is defined out of
// line.
class TlsConnectGeneric
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<
          std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsConnectGeneric();
};
// Resumption fixture parametrized over (protocol variant, version, bool).
// When external_cache_ is set, the client additionally registers a resumption
// token callback during EnsureTlsSetup().
class TlsConnectGenericResumption
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<
          std::tuple<SSLProtocolVariant, uint16_t, bool>> {
 public:
  TlsConnectGenericResumption();

  // Runs the base setup, then enables the client's resumption token callback
  // if the external cache is in use.  `override` lets the compiler verify this
  // really replaces TlsConnectTestBase::EnsureTlsSetup().
  void EnsureTlsSetup() override {
    TlsConnectTestBase::EnsureTlsSetup();
    // Enable external resumption token cache.
    if (external_cache_) {
      client_->SetResumptionTokenCallback();
    }
  }

  bool use_external_cache() const { return external_cache_; }

 private:
  // Whether the external resumption token cache is enabled.  Defaults to false
  // so it is never read uninitialized.  NOTE(review): presumably assigned from
  // the tuple's bool parameter by the out-of-line constructor — confirm.
  bool external_cache_ = false;
};
// Resumption-token fixture parametrized over the protocol variant.  The client
// always registers a resumption token callback during setup.
class TlsConnectTls13ResumptionToken
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<SSLProtocolVariant> {
 public:
  TlsConnectTls13ResumptionToken();

  // Runs the base setup, then enables the client's resumption token callback.
  // `override` lets the compiler verify the base hook is really replaced.
  void EnsureTlsSetup() override {
    TlsConnectTestBase::EnsureTlsSetup();
    client_->SetResumptionTokenCallback();
  }
};
// Resumption-token fixture parametrized over (protocol variant, version).  The
// client always registers a resumption token callback during setup.
class TlsConnectGenericResumptionToken
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<
          std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsConnectGenericResumptionToken();

  // Runs the base setup, then enables the client's resumption token callback.
  // `override` lets the compiler verify the base hook is really replaced.
  void EnsureTlsSetup() override {
    TlsConnectTestBase::EnsureTlsSetup();
    client_->SetResumptionTokenCallback();
  }
};
// Generic fixture for versions before TLS 1.2, parametrized over
// (protocol variant, version).  The constructor is defined out of line.
class TlsConnectPre12
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<
          std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsConnectPre12();
};
// TLS 1.2-only fixture; only the protocol variant is parametrized.  The
// constructor is defined out of line.
class TlsConnectTls12
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<SSLProtocolVariant> {
 public:
  TlsConnectTls12();
};
// Non-parametrized fixture fixed to stream (TLS) at version 1.2.
class TlsConnectStreamTls12 : public TlsConnectTestBase {
 public:
  TlsConnectStreamTls12()
      : TlsConnectTestBase(ssl_variant_stream,
                           SSL_LIBRARY_VERSION_TLS_1_2) {}
};
// Generic fixture for TLS 1.2 and later, parametrized over
// (protocol variant, version).  The constructor is defined out of line.
class TlsConnectTls12Plus
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<
          std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsConnectTls12Plus();
};
// TLS 1.3-only fixture; only the protocol variant is parametrized.  The
// constructor is defined out of line.
class TlsConnectTls13
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<SSLProtocolVariant> {
 public:
  TlsConnectTls13();
};
// Non-parametrized fixture fixed to stream (TLS) at version 1.3.
class TlsConnectStreamTls13 : public TlsConnectTestBase {
 public:
  TlsConnectStreamTls13()
      : TlsConnectTestBase(ssl_variant_stream,
                           SSL_LIBRARY_VERSION_TLS_1_3) {}
};
// Non-parametrized fixture fixed to datagram (DTLS) at version 1.3.
class TlsConnectDatagram13 : public TlsConnectTestBase {
 public:
  TlsConnectDatagram13()
      : TlsConnectTestBase(ssl_variant_datagram,
                           SSL_LIBRARY_VERSION_TLS_1_3) {}
};
// Datagram (DTLS) fixture that, going by its name, targets versions before
// 1.3.  It adds nothing to TlsConnectDatagram beyond being a distinct type.
class TlsConnectDatagramPre13 : public TlsConnectDatagram {
 public:
  TlsConnectDatagramPre13() {
    // Nothing to do: all setup happens in TlsConnectDatagram.
  }
};
// TlsConnectGeneric variant used only with pre-1.3 versions.  It adds nothing
// of its own; being a distinct type gives it a separate gtest instantiation.
class TlsConnectGenericPre13 : public TlsConnectGeneric {
  // Intentionally empty.
};
// Key-exchange fixture built on TlsConnectGeneric.  It keeps extension
// captures and a handshake recorder so tests can compare the groups / key
// shares seen on the wire against expectations.
class TlsKeyExchangeTest : public TlsConnectGeneric {
 protected:
  // Captures (per their names: groups, key shares, a second key share) and a
  // recorder intended for HelloRetryRequest.
  std::shared_ptr<TlsExtensionCapture> groups_capture_;
  std::shared_ptr<TlsExtensionCapture> shares_capture_;
  std::shared_ptr<TlsExtensionCapture> shares_capture2_;
  std::shared_ptr<TlsHandshakeRecorder> capture_hrr_;

  void EnsureKeyShareSetup();
  void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);

  // Decode the group lists held by a capture.
  std::vector<SSLNamedGroup> GetGroupDetails(
      const std::shared_ptr<TlsExtensionCapture>& capture);
  std::vector<SSLNamedGroup> GetShareDetails(
      const std::shared_ptr<TlsExtensionCapture>& capture);

  // Compare captured groups and shares with expectations; the three-argument
  // form also names the group expected in the second share.
  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
                       const std::vector<SSLNamedGroup>& expectedShares);
  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
                       const std::vector<SSLNamedGroup>& expectedShares,
                       SSLNamedGroup expectedShare2);

 private:
  // NOTE(review): presumably the shared implementation behind the public
  // overloads, with `expect_hrr` selecting whether an HRR is expected.
  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
                       const std::vector<SSLNamedGroup>& expectedShares,
                       bool expect_hrr);
};
// Key-exchange fixture variant that, by its name, is meant for TLS 1.3 tests.
class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {
  // Intentionally empty.
};
// Key-exchange fixture variant that, by its name, is meant for pre-1.3 tests.
class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {
  // Intentionally empty.
};
} // namespace nss_test
#endif