Revision control

Copy as Markdown

Other Tools

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use crate::client::MlsError;
use crate::key_package::KeyPackageRef;
use alloc::vec::Vec;
use mls_rs_codec::MlsEncode;
use mls_rs_core::{
error::IntoAnyError,
group::{GroupState, GroupStateStorage},
key_package::KeyPackageStorage,
};
use super::snapshot::Snapshot;
/// Repository that persists encoded MLS group state and, optionally, removes
/// the key package that was consumed when the group was joined.
///
/// Group snapshots are written to the `S` backend; the key package identified
/// by `pending_key_package_removal` (if any) is deleted from the `K` backend
/// once group state has been written.
#[derive(Debug, Clone)]
pub(crate) struct GroupStateRepository<S: GroupStateStorage, K: KeyPackageStorage> {
    /// Reference of a key package to delete from `key_package_repo` after
    /// group state is stored; `None` when there is nothing to clean up.
    pending_key_package_removal: Option<KeyPackageRef>,
    /// Backend that receives the encoded group state.
    storage: S,
    /// Backend holding key packages, used only to delete the pending one.
    key_package_repo: K,
}
impl<S, K> GroupStateRepository<S, K>
where
    S: GroupStateStorage,
    K: KeyPackageStorage,
{
    /// Creates a repository backed by `storage` for group state and
    /// `key_package_repo` for key packages.
    ///
    /// `key_package_to_remove` names a key package that will be deleted from
    /// `key_package_repo` after group state is successfully written.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`; the `Result` is part of
    /// its existing signature.
    pub fn new(
        storage: S,
        key_package_repo: K,
        // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
        key_package_to_remove: Option<KeyPackageRef>,
    ) -> Result<GroupStateRepository<S, K>, MlsError> {
        Ok(GroupStateRepository {
            storage,
            pending_key_package_removal: key_package_to_remove,
            key_package_repo,
        })
    }

    /// Encodes `group_snapshot`, writes it to group state storage (with no
    /// epoch inserts or updates), and then deletes the pending key package,
    /// if one is set.
    ///
    /// The pending key package is cleared once it has been deleted, so later
    /// writes do not issue repeated deletes against the key package store.
    ///
    /// # Errors
    ///
    /// - Any error produced while encoding the snapshot.
    /// - [`MlsError::GroupStorageError`] if the group state write fails; the
    ///   key package is not touched in that case.
    /// - [`MlsError::KeyPackageRepoError`] if deleting the key package fails;
    ///   the removal stays pending and is retried on the next write.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
        let group_state = GroupState {
            data: group_snapshot.mls_encode_to_vec()?,
            id: group_snapshot.state.context.group_id,
        };

        self.storage
            .write(group_state, Vec::new(), Vec::new())
            .await
            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;

        if let Some(key_package_ref) = &self.pending_key_package_removal {
            self.key_package_repo
                .delete(key_package_ref)
                .await
                .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;

            // The removal has been carried out; forget it so subsequent writes
            // of this group do not re-delete an already-removed key package.
            // Only reached on success, so a failed delete is retried later.
            self.pending_key_package_removal = None;
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            snapshot::{test_utils::get_test_snapshot, Snapshot},
            test_utils::{test_member, TEST_GROUP},
        },
        storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
    };
    use alloc::vec;

    use super::GroupStateRepository;

    /// Produces the shared test snapshot for `epoch_id` under the test cipher suite.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_snapshot(epoch_id: u64) -> Snapshot {
        get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
    }

    /// Writing a snapshot makes its group id appear in the stored group list.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_stored_groups_list() {
        let group_storage = InMemoryGroupStateStorage::default();
        let key_package_storage = InMemoryKeyPackageStorage::default();

        let mut repo = GroupStateRepository::new(group_storage, key_package_storage, None)
            .expect("constructing a repository never fails");

        let snapshot = test_snapshot(0).await;
        repo.write_to_storage(snapshot).await.unwrap();

        assert_eq!(repo.storage.stored_groups(), vec![TEST_GROUP]);
    }

    /// The key package passed to `new` is removed once group state is written.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn used_key_package_is_deleted() {
        let key_package_storage = InMemoryKeyPackageStorage::default();

        let member = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member").await;
        let key_package = member.0;

        let (stored_id, stored_data) = key_package.to_storage().unwrap();
        key_package_storage.insert(stored_id, stored_data);

        let consumed_ref = key_package.reference.clone();
        let mut repo = GroupStateRepository::new(
            InMemoryGroupStateStorage::default(),
            key_package_storage,
            Some(consumed_ref),
        )
        .expect("constructing a repository never fails");

        // Precondition: the key package is present before the write.
        assert!(repo.key_package_repo.get(&key_package.reference).is_some());

        let snapshot = test_snapshot(4).await;
        repo.write_to_storage(snapshot).await.unwrap();

        assert!(repo.key_package_repo.get(&key_package.reference).is_none());
    }
}