use std::{
collections::{hash_map::Entry, HashMap},
hash::Hash,
sync::{Arc, Weak},
};
use once_cell::sync::OnceCell;
use crate::lock::{rank, Mutex};
use crate::{PreHashedKey, PreHashedMap};
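// Each slot is shared between the map and any thread currently initializing it.
// The slot only ever stores a weak reference, so the pool by itself never keeps
// a resource alive.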
type SlotInner<V> = Weak<V>;
type ResourcePoolSlot<V> = Arc<OnceCell<SlotInner<V>>>;
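/// A pool that deduplicates resources by key.
///
/// The pool only holds weak references, so it never keeps a resource alive on
/// its own; a resource's `Drop` implementation is expected to call
/// [`Self::remove`] so the stale entry is cleaned up.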
pub struct ResourcePool<K, V> {
// We use a pre-hashed map as we never actually need to read the keys.
//
// This also means we only need to hash the key once per call to get_or_init.
inner: Mutex<PreHashedMap<K, ResourcePoolSlot<V>>>,
}
impl<K: Clone + Eq + Hash, V> ResourcePool<K, V> {
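/// Create a new, empty pool.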
pub fn new() -> Self {
Self {
inner: Mutex::new(rank::RESOURCE_POOL_INNER, HashMap::default()),
}
}
/// Get the resource from the pool with the given key, or create a new
/// one using the given constructor if it doesn't exist.
///
/// Behaves such that only one resource will be created for each unique
/// key at any one time.
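///
/// # Example
///
/// A minimal usage sketch, mirroring the `deduplication` test below; the paths
/// are illustrative and the example is not compiled as a doc-test:
///
/// ```ignore
/// let pool = ResourcePool::<u32, u32>::new();
///
/// // The constructor runs because no live resource exists for key 0 yet.
/// let first = pool.get_or_init::<_, ()>(0, |key| Ok(Arc::new(key)))?;
///
/// // A second call with the same key returns the same Arc without calling
/// // the constructor again.
/// let second = pool.get_or_init::<_, ()>(0, |key| Ok(Arc::new(key)))?;
/// assert!(Arc::ptr_eq(&first, &second));
/// ```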
pub fn get_or_init<F, E>(&self, key: K, constructor: F) -> Result<Arc<V>, E>
where
F: FnOnce(K) -> Result<Arc<V>, E>,
{
// Hash the key outside of the lock.
let hashed_key = PreHashedKey::from_key(&key);
// We can't prove at compile time that these will only ever be consumed once,
// so we need to do the check at runtime.
let mut key = Some(key);
let mut constructor = Some(constructor);
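// Loop until we either initialize the entry ourselves, or manage to upgrade an
// entry initialized by another thread. If we race with a resource that is
// currently being dropped, we go around again (see the end of the loop body).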
'race: loop {
let mut map_guard = self.inner.lock();
let entry = match map_guard.entry(hashed_key) {
// An entry exists for this resource.
//
// We know that either:
// - The resource is still alive, and Weak::upgrade will succeed.
// - The resource is in the process of being dropped, and Weak::upgrade will fail.
//
// The entry will never be empty while the resource is still alive.
Entry::Occupied(entry) => Arc::clone(entry.get()),
// No entry exists for this resource.
//
// We know that the resource is not alive, so we can create a new entry.
Entry::Vacant(entry) => Arc::clone(entry.insert(Arc::new(OnceCell::new()))),
};
drop(map_guard);
// Some other thread may beat us to initializing the entry, but OnceCell guarantees that only one thread
// will actually initialize the entry.
//
// We move the strong reference out of the closure so the resource stays alive, since the cell itself only stores a weak reference.
let mut strong = None;
let weak = entry.get_or_try_init(|| {
let strong_inner = constructor.take().unwrap()(key.take().unwrap())?;
let weak = Arc::downgrade(&strong_inner);
strong = Some(strong_inner);
Ok(weak)
})?;
// If strong is Some, that means we just initialized the entry, so we can just return it.
if let Some(strong) = strong {
return Ok(strong);
}
// The entry was already initialized by someone else, so we need to try to upgrade it.
if let Some(strong) = weak.upgrade() {
// We succeeded: the resource is still alive, so just return it.
return Ok(strong);
}
// The resource is in the process of being dropped, because upgrade failed. The entry still exists in the map, but it points to nothing.
//
// We're in a race with the drop implementation of the resource, so let's just go around again. When we go around again:
// - If the entry exists, we might need to go around a few more times.
// - If the entry doesn't exist, we'll create a new one.
continue 'race;
}
}
/// Remove the resource with the given key from the pool.
///
/// Must *only* be called in the Drop impl of [`BindGroupLayout`].
///
/// [`BindGroupLayout`]: crate::binding_model::BindGroupLayout
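///
/// A sketch of the intended call site, with a hypothetical `MyResource` type and
/// fields standing in for the real resource (illustrative only, not compiled as a
/// doc-test):
///
/// ```ignore
/// impl Drop for MyResource {
///     fn drop(&mut self) {
///         // By the time Drop runs, every Weak::upgrade on this entry is already
///         // failing, so it is safe to take the entry out of the map here.
///         self.pool.remove(&self.key);
///     }
/// }
/// ```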
pub fn remove(&self, key: &K) {
let hashed_key = PreHashedKey::from_key(key);
let mut map_guard = self.inner.lock();
// Weak::upgrade will be failing long before this code is called. All threads trying to access the resource will be spinning,
// waiting for the entry to be removed. It is safe to remove the entry from the map.
map_guard.remove(&hashed_key);
}
}
#[cfg(test)]
mod tests {
use std::sync::{
atomic::{AtomicU32, Ordering},
Barrier,
};
use super::*;
#[test]
fn deduplication() {
let pool = ResourcePool::<u32, u32>::new();
let mut counter = 0_u32;
let arc1 = pool
.get_or_init::<_, ()>(0, |key| {
counter += 1;
Ok(Arc::new(key))
})
.unwrap();
assert_eq!(*arc1, 0);
assert_eq!(counter, 1);
let arc2 = pool
.get_or_init::<_, ()>(0, |key| {
counter += 1;
Ok(Arc::new(key))
})
.unwrap();
assert!(Arc::ptr_eq(&arc1, &arc2));
assert_eq!(*arc2, 0);
assert_eq!(counter, 1);
drop(arc1);
drop(arc2);
pool.remove(&0);
let arc3 = pool
.get_or_init::<_, ()>(0, |key| {
counter += 1;
Ok(Arc::new(key))
})
.unwrap();
assert_eq!(*arc3, 0);
assert_eq!(counter, 2);
}
// Test name has "2_threads" in the name so nextest reserves two threads for it.
#[test]
fn concurrent_creation_2_threads() {
struct Resources {
pool: ResourcePool<u32, u32>,
counter: AtomicU32,
barrier: Barrier,
}
let resources = Arc::new(Resources {
pool: ResourcePool::<u32, u32>::new(),
counter: AtomicU32::new(0),
barrier: Barrier::new(2),
});
// Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
//
// To validate the expected order of events, we've put print statements in the code, indicating when each thread is at a certain point.
// The output will look something like this if the test is working as expected:
//
// ```
// 0: prewait
// 1: prewait
// 1: postwait
// 0: postwait
// 1: init
// 1: postget
// 0: postget
// ```
fn thread_inner(idx: u8, resources: &Resources) -> Arc<u32> {
eprintln!("{idx}: prewait");
// Once this returns, both threads should hit get_or_init at about the same time,
// allowing us to actually test concurrent creation.
//
// Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
resources.barrier.wait();
eprintln!("{idx}: postwait");
let ret = resources
.pool
.get_or_init::<_, ()>(0, |key| {
eprintln!("{idx}: init");
// Simulate a long-running constructor, ensuring that both threads will be in get_or_init at the same time.
std::thread::sleep(std::time::Duration::from_millis(250));
resources.counter.fetch_add(1, Ordering::SeqCst);
Ok(Arc::new(key))
})
.unwrap();
eprintln!("{idx}: postget");
ret
}
let thread1 = std::thread::spawn({
let resource_clone = Arc::clone(&resources);
move || thread_inner(1, &resource_clone)
});
let arc0 = thread_inner(0, &resources);
assert_eq!(resources.counter.load(Ordering::Acquire), 1);
let arc1 = thread1.join().unwrap();
assert!(Arc::ptr_eq(&arc0, &arc1));
}
// Test name has "2_threads" in the name so nextest reserves two threads for it.
#[test]
fn create_while_drop_2_threads() {
struct Resources {
pool: ResourcePool<u32, u32>,
barrier: Barrier,
}
let resources = Arc::new(Resources {
pool: ResourcePool::<u32, u32>::new(),
barrier: Barrier::new(2),
});
// Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
//
// To validate the expected order of events, we've put print statements in the code, indicating when each thread is at a certain point.
// The output will look something like this if the test is working as expected:
//
// ```
// 0: prewait
// 1: prewait
// 1: postwait
// 0: postwait
// 1: postsleep
// 1: removal
// 0: postget
// ```
//
// The last two _may_ be flipped.
let existing_entry = resources
.pool
.get_or_init::<_, ()>(0, |key| Ok(Arc::new(key)))
.unwrap();
// Drop the entry, but do _not_ remove it from the pool.
// This simulates the situation where the resource's Arc has been dropped, but the Drop
// implementation, which calls remove, has not yet run.
drop(existing_entry);
fn thread0_inner(resources: &Resources) {
eprintln!("0: prewait");
resources.barrier.wait();
eprintln!("0: postwait");
// We try to create a new entry, but the entry already exists.
//
// As Weak::upgrade is failing, we will just keep spinning until remove is called.
resources
.pool
.get_or_init::<_, ()>(0, |key| Ok(Arc::new(key)))
.unwrap();
eprintln!("0: postget");
}
fn thread1_inner(resources: &Resources) {
eprintln!("1: prewait");
resources.barrier.wait();
eprintln!("1: postwait");
// We wait a little bit, making sure that thread0_inner has started spinning.
std::thread::sleep(std::time::Duration::from_millis(250));
eprintln!("1: postsleep");
// We remove the entry from the pool, allowing thread0_inner to re-create it.
resources.pool.remove(&0);
eprintln!("1: removal");
}
let thread1 = std::thread::spawn({
let resource_clone = Arc::clone(&resources);
move || thread1_inner(&resource_clone)
});
thread0_inner(&resources);
thread1.join().unwrap();
}
}