// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use crate::client::MlsError;
use crate::{group::PriorEpoch, key_package::KeyPackageRef};
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use core::fmt::{self, Debug};
use mls_rs_codec::{MlsDecode, MlsEncode};
use mls_rs_core::group::{EpochRecord, GroupState};
use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage};
use super::snapshot::Snapshot;
#[cfg(feature = "psk")]
use crate::group::ResumptionPsk;
#[cfg(feature = "psk")]
use mls_rs_core::psk::PreSharedKey;
/// A set of changes to apply to a GroupStateStorage implementation. These changes MUST
/// be made in a single transaction to avoid creating invalid states.
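///
/// `inserts` holds newly created epochs in ascending, contiguous epoch-ID
/// order (`insert` enforces that each staged epoch is exactly `max + 1`), so
/// the epoch with ID `min + i` lives at index `i`. `updates` holds modified
/// copies of epochs that are already persisted.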
#[derive(Default, Clone, Debug)]
struct EpochStorageCommit {
pub(crate) inserts: VecDeque<PriorEpoch>,
pub(crate) updates: Vec<PriorEpoch>,
}
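/// Tracks pending epoch inserts and updates for a single group, and flushes
/// them together with a serialized [`Snapshot`] to a [`GroupStateStorage`]
/// in a single write.
///
/// A minimal lifecycle sketch (illustrative only; `group_id`, `prior_epoch`,
/// and `snapshot` are placeholder values, and the `.await`s apply only to
/// async builds):
///
/// ```ignore
/// let mut repo = GroupStateRepository::new(group_id, storage, kp_repo, None)?;
/// repo.insert(prior_epoch).await?;        // stage epoch max + 1
/// repo.write_to_storage(snapshot).await?; // persist snapshot + staged epochs
/// ```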
#[derive(Clone)]
pub(crate) struct GroupStateRepository<S, K>
where
S: GroupStateStorage,
K: KeyPackageStorage,
{
pending_commit: EpochStorageCommit,
pending_key_package_removal: Option<KeyPackageRef>,
group_id: Vec<u8>,
storage: S,
key_package_repo: K,
}
impl<S, K> Debug for GroupStateRepository<S, K>
where
S: GroupStateStorage + Debug,
K: KeyPackageStorage + Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("GroupStateRepository")
.field("pending_commit", &self.pending_commit)
.field(
"pending_key_package_removal",
&self.pending_key_package_removal,
)
.field(
"group_id",
&mls_rs_core::debug::pretty_group_id(&self.group_id),
)
.field("storage", &self.storage)
.field("key_package_repo", &self.key_package_repo)
.finish()
}
}
impl<S, K> GroupStateRepository<S, K>
where
S: GroupStateStorage,
K: KeyPackageStorage,
{
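/// Creates a repository for `group_id` backed by `storage`.
///
/// `key_package_to_remove` is the reference to the key package used to join
/// the group; it is deleted from `key_package_repo` on the next successful
/// `write_to_storage` call.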
pub fn new(
group_id: Vec<u8>,
storage: S,
key_package_repo: K,
// Set to `None` if restoring from snapshot; set to `Some` when joining a group.
key_package_to_remove: Option<KeyPackageRef>,
) -> Result<GroupStateRepository<S, K>, MlsError> {
Ok(GroupStateRepository {
group_id,
storage,
pending_key_package_removal: key_package_to_remove,
pending_commit: Default::default(),
key_package_repo,
})
}
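/// Returns the highest known epoch ID: the newest pending insert if one
/// exists, otherwise the maximum epoch ID in the backing storage.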
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn find_max_id(&self) -> Result<Option<u64>, MlsError> {
if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) {
Ok(Some(max))
} else {
self.storage
.max_epoch_id(&self.group_id)
.await
.map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
}
}
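/// Looks up the resumption secret for `psk_id`, checking the pending
/// inserts first, then the pending updates, and finally the backing storage.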
#[cfg(feature = "psk")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn resumption_secret(
&self,
psk_id: &ResumptionPsk,
) -> Result<Option<PreSharedKey>, MlsError> {
// Search the local inserts cache
if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
if psk_id.psk_epoch >= min {
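// Pending inserts are contiguous and ascending, so the epoch with ID
// `psk_epoch` sits at offset `psk_epoch - min` if it is cached here.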
return Ok(self
.pending_commit
.inserts
.get((psk_id.psk_epoch - min) as usize)
.map(|e| e.secrets.resumption_secret.clone()));
}
}
// Search the local updates cache
let maybe_pending = self.find_pending(psk_id.psk_epoch);
if let Some(pending) = maybe_pending {
return Ok(Some(
self.pending_commit.updates[pending]
.secrets
.resumption_secret
.clone(),
));
}
// Finally, search the backing storage
self.storage
.epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch)
.await
.map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
.map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret))
.transpose()
}
#[cfg(feature = "private_message")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn get_epoch_mut(
&mut self,
epoch_id: u64,
) -> Result<Option<&mut PriorEpoch>, MlsError> {
// Search the local inserts cache
if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
if epoch_id >= min {
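// Same contiguity property as in `resumption_secret`: the pending
// insert for `epoch_id` sits at offset `epoch_id - min`.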
return Ok(self
.pending_commit
.inserts
.get_mut((epoch_id - min) as usize));
}
}
// Look in the cached updates, and if the epoch is not found there, load it
// from storage and cache it in the updates list for future access
match self.find_pending(epoch_id) {
Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
None => self
.storage
.epoch(&self.group_id, epoch_id)
.await
.map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
.and_then(|epoch| {
PriorEpoch::mls_decode(&mut &*epoch)
.map(|epoch| {
self.pending_commit.updates.push(epoch);
self.pending_commit.updates.last_mut()
})
.transpose()
}),
}
.transpose()
.map_err(Into::into)
}
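/// Stages `epoch` for insertion on the next `write_to_storage` call. The
/// epoch must belong to this group, and its ID must directly follow the
/// highest known epoch ID.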
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
if epoch.group_id() != self.group_id {
return Err(MlsError::GroupIdMismatch);
}
let epoch_id = epoch.epoch_id();
if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) {
if epoch_id != expected_id {
return Err(MlsError::InvalidEpoch);
}
}
self.pending_commit.inserts.push_back(epoch);
Ok(())
}
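/// Persists the group snapshot together with all pending epoch inserts and
/// updates in a single storage write, deletes the pending key package if one
/// was set, and clears the in-memory caches.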
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
let inserts = self
.pending_commit
.inserts
.iter()
.map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
.collect::<Result<_, MlsError>>()?;
let updates = self
.pending_commit
.updates
.iter()
.map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
.collect::<Result<_, MlsError>>()?;
let group_state = GroupState {
data: group_snapshot.mls_encode_to_vec()?,
id: group_snapshot.state.context.group_id,
};
self.storage
.write(group_state, inserts, updates)
.await
.map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
if let Some(ref key_package_ref) = self.pending_key_package_removal {
self.key_package_repo
.delete(key_package_ref)
.await
.map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
}
self.pending_commit.inserts.clear();
self.pending_commit.updates.clear();
Ok(())
}
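/// Returns the index into the pending updates of the cached epoch with
/// `epoch_id`, if any.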
#[cfg(any(feature = "psk", feature = "private_message"))]
fn find_pending(&self, epoch_id: u64) -> Option<usize> {
self.pending_commit
.updates
.iter()
.position(|ep| ep.context.epoch == epoch_id)
}
}
#[cfg(test)]
mod tests {
use alloc::vec;
use mls_rs_codec::MlsEncode;
use crate::{
client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
group::{
epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret},
test_utils::{random_bytes, test_member, TEST_GROUP},
PskGroupId, ResumptionPSKUsage,
},
storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
};
use super::*;
fn test_group_state_repo(
retention_limit: usize,
) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
GroupStateRepository::new(
TEST_GROUP.to_vec(),
InMemoryGroupStateStorage::new()
.with_max_epoch_retention(retention_limit)
.unwrap(),
InMemoryKeyPackageStorage::default(),
None,
)
.unwrap()
}
fn test_epoch(epoch_id: u64) -> PriorEpoch {
get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id)
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn test_snapshot(epoch_id: u64) -> Snapshot {
crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_epoch_inserts() {
let mut test_repo = test_group_state_repo(1);
let test_epoch = test_epoch(0);
test_repo.insert(test_epoch.clone()).await.unwrap();
// Check the in-memory state
assert_eq!(
test_repo.pending_commit.inserts.back().unwrap(),
&test_epoch
);
assert!(test_repo.pending_commit.updates.is_empty());
#[cfg(feature = "std")]
assert!(test_repo.storage.inner.lock().unwrap().is_empty());
#[cfg(not(feature = "std"))]
assert!(test_repo.storage.inner.lock().is_empty());
let psk_id = ResumptionPsk {
psk_epoch: 0,
psk_group_id: PskGroupId(test_repo.group_id.clone()),
usage: ResumptionPSKUsage::Application,
};
// Make sure an epoch that is still a pending insert can be retrieved
let resumption = test_repo.resumption_secret(&psk_id).await.unwrap();
let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned();
assert_eq!(
prior_epoch.clone().unwrap().secrets.resumption_secret,
resumption.unwrap()
);
assert_eq!(prior_epoch.unwrap(), test_epoch);
// Write to the storage
let snapshot = test_snapshot(test_epoch.epoch_id()).await;
test_repo.write_to_storage(snapshot.clone()).await.unwrap();
// Make sure the in-memory cache was cleared
assert!(test_repo.pending_commit.inserts.is_empty());
assert!(test_repo.pending_commit.updates.is_empty());
// Make sure the storage was written
#[cfg(feature = "std")]
let storage = test_repo.storage.inner.lock().unwrap();
#[cfg(not(feature = "std"))]
let storage = test_repo.storage.inner.lock();
assert_eq!(storage.len(), 1);
let stored = storage.get(TEST_GROUP).unwrap();
assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
assert_eq!(stored.epoch_data.len(), 1);
assert_eq!(
stored.epoch_data.back().unwrap(),
&EpochRecord::new(
test_epoch.epoch_id(),
test_epoch.mls_encode_to_vec().unwrap()
)
);
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_updates() {
let mut test_repo = test_group_state_repo(2);
let test_epoch_0 = test_epoch(0);
test_repo.insert(test_epoch_0.clone()).await.unwrap();
test_repo
.write_to_storage(test_snapshot(0).await)
.await
.unwrap();
// Update the stored epoch
let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
assert_eq!(to_update, &test_epoch_0);
let new_sender_secret = random_bytes(32);
to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
let to_update = to_update.clone();
assert_eq!(test_repo.pending_commit.updates.len(), 1);
assert!(test_repo.pending_commit.inserts.is_empty());
assert_eq!(
test_repo.pending_commit.updates.first().unwrap(),
&to_update
);
// Make sure an epoch with a pending update can still be accessed
let psk_id = ResumptionPsk {
psk_epoch: 0,
psk_group_id: PskGroupId(test_repo.group_id.clone()),
usage: ResumptionPSKUsage::Application,
};
let owned = test_repo.resumption_secret(&psk_id).await.unwrap();
assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret));
// Write the update to storage
let snapshot = test_snapshot(1).await;
test_repo.write_to_storage(snapshot.clone()).await.unwrap();
assert!(test_repo.pending_commit.updates.is_empty());
assert!(test_repo.pending_commit.inserts.is_empty());
// Make sure the storage was written
#[cfg(feature = "std")]
let storage = test_repo.storage.inner.lock().unwrap();
#[cfg(not(feature = "std"))]
let storage = test_repo.storage.inner.lock();
assert_eq!(storage.len(), 1);
let stored = storage.get(TEST_GROUP).unwrap();
assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
assert_eq!(stored.epoch_data.len(), 1);
assert_eq!(
stored.epoch_data.back().unwrap(),
&EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
);
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_insert_and_update() {
let mut test_repo = test_group_state_repo(2);
let test_epoch_0 = test_epoch(0);
test_repo.insert(test_epoch_0).await.unwrap();
test_repo
.write_to_storage(test_snapshot(0).await)
.await
.unwrap();
// Update the stored epoch
let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
let new_sender_secret = random_bytes(32);
to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
let to_update = to_update.clone();
// Insert another epoch
let test_epoch_1 = test_epoch(1);
test_repo.insert(test_epoch_1.clone()).await.unwrap();
test_repo
.write_to_storage(test_snapshot(1).await)
.await
.unwrap();
assert!(test_repo.pending_commit.inserts.is_empty());
assert!(test_repo.pending_commit.updates.is_empty());
// Make sure the storage was written
#[cfg(feature = "std")]
let storage = test_repo.storage.inner.lock().unwrap();
#[cfg(not(feature = "std"))]
let storage = test_repo.storage.inner.lock();
assert_eq!(storage.len(), 1);
let stored = storage.get(TEST_GROUP).unwrap();
assert_eq!(stored.epoch_data.len(), 2);
assert_eq!(
stored.epoch_data.front().unwrap(),
&EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
);
assert_eq!(
stored.epoch_data.back().unwrap(),
&EpochRecord::new(
test_epoch_1.epoch_id(),
test_epoch_1.mls_encode_to_vec().unwrap()
)
);
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_many_epochs_in_storage() {
let epochs = (0..10).map(test_epoch).collect::<Vec<_>>();
let mut test_repo = test_group_state_repo(10);
for epoch in epochs.iter().cloned() {
test_repo.insert(epoch).await.unwrap()
}
test_repo
.write_to_storage(test_snapshot(9).await)
.await
.unwrap();
for mut epoch in epochs {
let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap();
assert_eq!(res, Some(&mut epoch));
}
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_stored_groups_list() {
let mut test_repo = test_group_state_repo(2);
let test_epoch_0 = test_epoch(0);
test_repo.insert(test_epoch_0.clone()).await.unwrap();
test_repo
.write_to_storage(test_snapshot(0).await)
.await
.unwrap();
assert_eq!(
test_repo.storage.stored_groups(),
vec![test_epoch_0.context.group_id]
)
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn reducing_retention_limit_takes_effect_on_epoch_access() {
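// With a retention limit of 1, writing two staged epochs leaves only
// epoch 1 in storage, so a fresh repository over the same storage can no
// longer access epoch 0.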
let mut repo = test_group_state_repo(1);
repo.insert(test_epoch(0)).await.unwrap();
repo.insert(test_epoch(1)).await.unwrap();
repo.write_to_storage(test_snapshot(0).await).await.unwrap();
let mut repo = GroupStateRepository {
storage: repo.storage,
..test_group_state_repo(1)
};
let res = repo.get_epoch_mut(0).await.unwrap();
assert!(res.is_none());
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn in_memory_storage_obeys_retention_limit_after_saving() {
let mut repo = test_group_state_repo(1);
repo.insert(test_epoch(0)).await.unwrap();
repo.write_to_storage(test_snapshot(0).await).await.unwrap();
repo.insert(test_epoch(1)).await.unwrap();
repo.write_to_storage(test_snapshot(1).await).await.unwrap();
#[cfg(feature = "std")]
let lock = repo.storage.inner.lock().unwrap();
#[cfg(not(feature = "std"))]
let lock = repo.storage.inner.lock();
assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1);
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn used_key_package_is_deleted() {
let key_package_repo = InMemoryKeyPackageStorage::default();
let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
.await
.0;
let (id, data) = key_package.to_storage().unwrap();
key_package_repo.insert(id, data);
let mut repo = GroupStateRepository::new(
TEST_GROUP.to_vec(),
InMemoryGroupStateStorage::new(),
key_package_repo,
Some(key_package.reference.clone()),
)
.unwrap();
repo.key_package_repo.get(&key_package.reference).unwrap();
repo.write_to_storage(test_snapshot(4).await).await.unwrap();
assert!(repo.key_package_repo.get(&key_package.reference).is_none());
}
}