Source code

Revision control

Copy as Markdown

Other Tools

use alloc::{borrow::Cow, string::String};
/// Describes which runtime checks should be performed on shaders.
///
/// Each field toggles one class of check: buffer bounds checks, loop
/// bounding, and ray query initialization tracking. Disabling a check moves
/// the corresponding obligation onto the caller, as described on each field.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ShaderRuntimeChecks {
/// Enforce bounds checks in shaders, even if the underlying driver doesn't
/// support doing so natively.
///
/// When this is `true`, `wgpu` promises that shaders can only read or
/// write the accessible region of a bindgroup's buffer bindings. If
/// the underlying graphics platform cannot implement these bounds checks
/// itself, `wgpu` will inject bounds checks before presenting the
/// shader to the platform.
///
/// When this is `false`, `wgpu` only enforces such bounds checks if the
/// underlying platform provides a way to do so itself. `wgpu` does not
/// itself add any bounds checks to generated shader code.
///
/// Note that `wgpu` users may try to initialize only those portions of
/// buffers that they anticipate might be read from. Passing `false` here
/// may allow shaders to see wider regions of the buffers than expected,
/// making such deferred initialization visible to the application.
pub bounds_checks: bool,
/// Controls the caller's obligations regarding infinite loops in shaders.
///
/// If `false`, the caller MUST ensure that all passed shaders do not contain
/// any infinite loops.
///
/// If a shader does contain one, backend compilers MAY treat such a loop as
/// unreachable code and draw conclusions about other safety-critical code
/// paths. This option SHOULD NOT be disabled when running untrusted code.
pub force_loop_bounding: bool,
/// Controls the caller's obligations regarding the order of ray query calls.
///
/// If `false`, the caller **MUST** ensure that, in all passed shaders, every
/// function operating on a ray query obeys these rules (functions are named
/// as in WGSL):
/// - `rayQueryInitialize` must have been called before `rayQueryProceed`.
/// - `rayQueryProceed` must have been called, returned true and have hit an
///   AABB before `rayQueryGenerateIntersection` is called.
/// - `rayQueryProceed` must have been called, returned true and have hit a
///   triangle before `rayQueryConfirmIntersection` is called.
/// - `rayQueryProceed` must have been called and have returned true before
///   `rayQueryTerminate`, `getCandidateHitVertexPositions` or
///   `rayQueryGetCandidateIntersection` is called.
/// - `rayQueryProceed` must have been called and have returned false before
///   `rayQueryGetCommittedIntersection` or `getCommittedHitVertexPositions`
///   is called.
///
/// The aim is that these cases will not cause UB when this is set to `true`,
/// but currently UB will still happen on DX12 and Metal.
pub ray_query_initialization_tracking: bool,
}
impl ShaderRuntimeChecks {
    /// Returns a configuration in which every runtime check is enabled.
    ///
    /// This is the safe way to obtain a fully checked configuration; the
    /// result is the same as `all(true)`.
    #[must_use]
    pub const fn checked() -> Self {
        Self {
            bounds_checks: true,
            force_loop_bounding: true,
            ray_query_initialization_tracking: true,
        }
    }

    /// Returns a configuration in which every runtime check is disabled.
    ///
    /// # Safety
    ///
    /// With a check disabled, the caller takes on the obligations described
    /// in the documentation of the corresponding field of
    /// [`ShaderRuntimeChecks`].
    #[must_use]
    pub const fn unchecked() -> Self {
        Self {
            bounds_checks: false,
            force_loop_bounding: false,
            ray_query_initialization_tracking: false,
        }
    }

    /// Returns a configuration in which every check is set to `all_checks`.
    ///
    /// To create a configuration with all checks enabled without `unsafe`,
    /// use [`ShaderRuntimeChecks::checked`].
    ///
    /// # Safety
    ///
    /// When `all_checks` is `false`, the caller takes on the obligations
    /// described in the documentation of each field of
    /// [`ShaderRuntimeChecks`].
    #[must_use]
    pub const unsafe fn all(all_checks: bool) -> Self {
        if all_checks {
            Self::checked()
        } else {
            Self::unchecked()
        }
    }
}
impl Default for ShaderRuntimeChecks {
fn default() -> Self {
Self::checked()
}
}
/// Descriptor for a shader module given by any of several sources.
///
/// These shaders are passed through directly to the underlying API.
/// At least one shader type that may be used by the backend must be `Some`,
/// or a panic is raised.
///
/// `L` is the type of the debug label, and `'a` is the lifetime of any
/// shader data that is borrowed rather than owned.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CreateShaderModuleDescriptorPassthrough<'a, L> {
/// Entry point. Unused for SPIR-V.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z. Unused for SPIR-V.
pub num_workgroups: (u32, u32, u32),
/// Runtime checks that should be enabled.
pub runtime_checks: ShaderRuntimeChecks,
/// Binary SPIR-V data, in 4-byte words.
pub spirv: Option<Cow<'a, [u32]>>,
/// Shader DXIL data, as bytes.
pub dxil: Option<Cow<'a, [u8]>>,
/// Shader MSL source.
pub msl: Option<Cow<'a, str>>,
/// Shader HLSL source.
pub hlsl: Option<Cow<'a, str>>,
/// Shader GLSL source (currently unused).
pub glsl: Option<Cow<'a, str>>,
/// Shader WGSL source.
pub wgsl: Option<Cow<'a, str>>,
}
// Provided so that callers can use struct-update syntax and skip the fields
// they don't need, such as `num_workgroups`, `entry_point`, or the shader
// languages they didn't compile for.
impl<'a, L> Default for CreateShaderModuleDescriptorPassthrough<'a, L>
where
    L: Default,
{
    /// Returns a descriptor with an empty entry point, a default label, zero
    /// workgroups, all runtime checks disabled and no shader sources.
    fn default() -> Self {
        Self {
            entry_point: String::new(),
            label: L::default(),
            num_workgroups: (0, 0, 0),
            runtime_checks: ShaderRuntimeChecks::unchecked(),
            spirv: None,
            dxil: None,
            msl: None,
            hlsl: None,
            glsl: None,
            wgsl: None,
        }
    }
}
impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
    /// Takes a closure and maps the label of the shader module descriptor into another.
    ///
    /// The closure receives a reference to the current label. Every other
    /// field is cloned from `self` into the returned descriptor, and `self`
    /// is left untouched.
    pub fn map_label<K>(
        &self,
        fun: impl FnOnce(&L) -> K,
    ) -> CreateShaderModuleDescriptorPassthrough<'a, K> {
        CreateShaderModuleDescriptorPassthrough {
            entry_point: self.entry_point.clone(),
            label: fun(&self.label),
            num_workgroups: self.num_workgroups,
            runtime_checks: self.runtime_checks,
            spirv: self.spirv.clone(),
            dxil: self.dxil.clone(),
            msl: self.msl.clone(),
            hlsl: self.hlsl.clone(),
            glsl: self.glsl.clone(),
            wgsl: self.wgsl.clone(),
        }
    }

    /// Returns the source data for tracing purposes.
    ///
    /// The first source that is `Some` is returned, checked in the order
    /// SPIR-V, MSL, DXIL, HLSL, WGSL, GLSL. This is the same order used by
    /// [`Self::trace_binary_ext`], so the two always describe the same source.
    ///
    /// # Panics
    ///
    /// Panics if no shader source is set at all.
    #[cfg(feature = "trace")]
    pub fn trace_data(&self) -> &[u8] {
        if let Some(spirv) = &self.spirv {
            bytemuck::cast_slice(spirv)
        } else if let Some(msl) = &self.msl {
            msl.as_bytes()
        } else if let Some(dxil) = &self.dxil {
            dxil
        } else if let Some(hlsl) = &self.hlsl {
            // Text sources are traced as their UTF-8 bytes, like MSL above.
            hlsl.as_bytes()
        } else if let Some(wgsl) = &self.wgsl {
            wgsl.as_bytes()
        } else if let Some(glsl) = &self.glsl {
            glsl.as_bytes()
        } else {
            panic!("No shader source provided to `CreateShaderModuleDescriptorPassthrough`")
        }
    }

    /// Returns the binary file extension for tracing purposes.
    ///
    /// The extension corresponds to the first source that is `Some`, checked
    /// in the order SPIR-V, MSL, DXIL, HLSL, WGSL, GLSL. This is the same
    /// order used by [`Self::trace_data`].
    ///
    /// # Panics
    ///
    /// Panics if no shader source is set at all.
    #[cfg(feature = "trace")]
    pub fn trace_binary_ext(&self) -> &'static str {
        if self.spirv.is_some() {
            "spv"
        } else if self.msl.is_some() {
            "msl"
        } else if self.dxil.is_some() {
            "dxil"
        } else if self.hlsl.is_some() {
            "hlsl"
        } else if self.wgsl.is_some() {
            "wgsl"
        } else if self.glsl.is_some() {
            "glsl"
        } else {
            panic!("No shader source provided to `CreateShaderModuleDescriptorPassthrough`")
        }
    }
}