Implement ISession and remove ICompileRequest

The ICompileRequest is part of Slang's legacy API, and will be removed
in the future. The ISession and associated interfaces are to be used
instead.
This commit is contained in:
Lauro Oyen
2024-04-04 13:32:30 +02:00
parent 5abe4533ba
commit 095c719ff8
4 changed files with 507 additions and 235 deletions

View File

@ -1,15 +1,24 @@
use slang_sys as sys;
use std::ffi::{CStr, CString};
use std::path::Path;
use std::slice;
use std::ptr::{null, null_mut};
pub use sys::SlangUUID as UUID;
pub use sys::SlangCompileTarget as CompileTarget;
pub use sys::SlangMatrixLayoutMode as MatrixLayoutMode;
pub use sys::SlangOptimizationLevel as OptimizationLevel;
pub use sys::SlangSourceLanguage as SourceLanguage;
pub use sys::SlangStage as Stage;
pub use sys::SlangProfileID as ProfileID;
use slang_sys as sys;
pub use sys::{
slang_CompilerOptionName as CompilerOptionName,
slang_SessionDesc as SessionDesc,
slang_TargetDesc as TargetDesc,
SlangCapabilityID as CapabilityID,
SlangCompileTarget as CompileTarget,
SlangDebugInfoLevel as DebugInfoLevel,
SlangFloatingPointMode as FloatingPointMode,
SlangLineDirectiveMode as LineDirectiveMode,
SlangMatrixLayoutMode as MatrixLayoutMode,
SlangOptimizationLevel as OptimizationLevel,
SlangProfileID as ProfileID,
SlangSourceLanguage as SourceLanguage,
SlangStage as Stage,
SlangUUID as UUID,
};
macro_rules! vcall {
($self:expr, $method:ident($($args:expr),*)) => {
@ -34,6 +43,15 @@ unsafe trait Interface: Sized {
unsafe fn as_raw<T>(&self) -> *mut T {
std::mem::transmute_copy(self)
}
fn as_unknown(&self) -> &IUnknown {
// SAFETY: It is always safe to treat an `Interface` as an `IUnknown`.
unsafe { std::mem::transmute(self) }
}
}
pub unsafe trait Downcast<T> {
fn downcast(&self) -> &T;
}
#[repr(transparent)]
@ -41,7 +59,7 @@ pub struct IUnknown(std::ptr::NonNull<std::ffi::c_void>);
unsafe impl Interface for IUnknown {
type Vtable = sys::ISlangUnknown__bindgen_vtable;
const IID: UUID = uuid(0x00000000, 0x0000, 0x0000, [0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46]);
const IID: UUID = uuid(0x00000000, 0x0000, 0x0000, [0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46]);
}
impl Clone for IUnknown {
@ -57,170 +75,416 @@ impl Drop for IUnknown {
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct Blob(IUnknown);
unsafe impl Interface for Blob {
type Vtable = sys::IBlobVtable;
const IID: UUID = uuid(0x8ba5fb08, 0x5195, 0x40e2, [0xac, 0x58, 0x0d, 0x98, 0x9c, 0x3a, 0x01, 0x02]);
}
impl Blob {
pub fn as_slice(&self) -> &[u8] {
let ptr = vcall!(self, getBufferPointer());
let size = vcall!(self, getBufferSize());
unsafe { std::slice::from_raw_parts(ptr as *const u8, size) }
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct GlobalSession(IUnknown);
unsafe impl Interface for GlobalSession {
type Vtable = sys::IGlobalSessionVtable;
const IID: UUID = uuid(0xc140b5fd, 0xc78, 0x452e, [0xba, 0x7c, 0x1a, 0x1e, 0x70, 0xc7, 0xf7, 0x1c]);
const IID: UUID = uuid(0xc140b5fd, 0x0c78, 0x452e, [0xba, 0x7c, 0x1a, 0x1e, 0x70, 0xc7, 0xf7, 0x1c]);
}
impl GlobalSession {
pub fn new() -> GlobalSession {
unsafe {
let mut global_session = std::ptr::null_mut();
sys::slang_createGlobalSession(sys::SLANG_API_VERSION as _, &mut global_session);
GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _).unwrap()))
}
pub fn new() -> Option<GlobalSession> {
let mut global_session = null_mut();
unsafe { sys::slang_createGlobalSession(sys::SLANG_API_VERSION as _, &mut global_session) };
Some(GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _)?)))
}
pub fn new_without_std_lib() -> GlobalSession {
unsafe {
let mut global_session = std::ptr::null_mut();
sys::slang_createGlobalSessionWithoutStdLib(sys::SLANG_API_VERSION as _, &mut global_session);
GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _).unwrap()))
}
pub fn new_without_std_lib() -> Option<GlobalSession> {
let mut global_session = null_mut();
unsafe { sys::slang_createGlobalSessionWithoutStdLib(sys::SLANG_API_VERSION as _, &mut global_session) };
Some(GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _)?)))
}
pub fn create_compile_request(&self) -> CompileRequest {
let mut compile_request = std::ptr::null_mut();
vcall!(self, createCompileRequest(&mut compile_request));
CompileRequest(IUnknown(std::ptr::NonNull::new(compile_request).unwrap()))
pub fn create_session(&self, desc: &SessionDesc) -> Option<Session> {
let mut session = null_mut();
let res = vcall!(self, createSession(desc, &mut session));
let session = Session(IUnknown(std::ptr::NonNull::new(session as *mut _)?));
// TODO: Without adding an extra reference, the code crashes when Session is dropped.
// Investigate why this is happening, the current solution could cause a memory leak.
unsafe { (session.as_unknown().vtable().ISlangUnknown_addRef)(session.as_raw()) };
Some(session)
}
pub fn find_profile(&self, name: &str) -> ProfileID {
let name = CString::new(name).unwrap();
vcall!(self, findProfile(name.as_ptr()))
}
pub fn find_capability(&self, name: &str) -> CapabilityID {
let name = CString::new(name).unwrap();
vcall!(self, findCapability(name.as_ptr()))
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct CompileRequest(IUnknown);
pub struct Session(IUnknown);
unsafe impl Interface for CompileRequest {
type Vtable = sys::ICompileRequestVtable;
const IID: UUID = uuid(0x96d33993, 0x317c, 0x4db5, [0xaf, 0xd8, 0x66, 0x6e, 0xe7, 0x72, 0x48, 0xe2]);
unsafe impl Interface for Session {
type Vtable = sys::ISessionVtable;
const IID: UUID = uuid(0x67618701, 0xd116, 0x468f, [0xab, 0x3b, 0x47, 0x4b, 0xed, 0xce, 0x0e, 0x3d]);
}
impl CompileRequest {
pub fn set_codegen_target(&mut self, target: CompileTarget) -> &mut Self {
vcall!(self, setCodeGenTarget(target));
self
}
pub fn set_matrix_layout_mode(&mut self, layout: MatrixLayoutMode) -> &mut Self {
vcall!(self, setMatrixLayoutMode(layout));
self
}
pub fn set_optimization_level(&mut self, level: OptimizationLevel) -> &mut Self {
vcall!(self, setOptimizationLevel(level));
self
}
pub fn set_target_profile(&mut self, profile: ProfileID) ->&mut Self {
vcall!(self, setTargetProfile(0, profile));
self
}
pub fn add_preprocessor_define(&mut self, key: &str, value: &str) -> &mut Self {
let key = CString::new(key).unwrap();
let value = CString::new(value).unwrap();
vcall!(self, addPreprocessorDefine(key.as_ptr(), value.as_ptr()));
self
}
pub fn add_search_path(&mut self, path: impl AsRef<Path>) -> &mut Self {
let path = CString::new(path.as_ref().to_str().unwrap()).unwrap();
vcall!(self, addSearchPath(path.as_ptr()));
self
}
pub fn add_translation_unit(&mut self, source_language: SourceLanguage, name: Option<&str>) -> TranslationUnit {
let name = CString::new(name.unwrap_or("")).unwrap();
let index = vcall!(self, addTranslationUnit(source_language, name.as_ptr()));
TranslationUnit {
request: self,
index,
}
}
pub fn compile(self) -> Result<CompiledRequest, CompilationErrors> {
let r = vcall!(self, compile());
if r < 0 {
let out = vcall!(self, getDiagnosticOutput());
let errors = unsafe { CStr::from_ptr(out).to_str().unwrap().to_string() };
Err(CompilationErrors { errors })
} else {
Ok(CompiledRequest { request: self })
}
}
}
pub struct TranslationUnit<'a> {
request: &'a mut CompileRequest,
index: i32,
}
impl<'a> TranslationUnit<'a> {
pub fn add_preprocessor_define(&mut self, key: &str, value: &str) -> &mut Self {
let key = CString::new(key).unwrap();
let value = CString::new(value).unwrap();
vcall!(self.request, addTranslationUnitPreprocessorDefine(self.index, key.as_ptr(), value.as_ptr()));
self
}
pub fn add_source_file(&mut self, path: impl AsRef<Path>) -> &mut Self {
let path = CString::new(path.as_ref().to_str().unwrap()).unwrap();
vcall!(self.request, addTranslationUnitSourceFile(self.index, path.as_ptr()));
self
}
pub fn add_source_string(&mut self, path: impl AsRef<Path>, source: &str) -> &mut Self {
let path = CString::new(path.as_ref().to_str().unwrap()).unwrap();
let source = CString::new(source).unwrap();
vcall!(self.request, addTranslationUnitSourceString(self.index, path.as_ptr(), source.as_ptr()));
self
}
pub fn add_entry_point(&mut self, name: &str, stage: Stage) -> EntryPointIndex {
impl Session {
pub fn load_module(&self, name: &str) -> Result<Module, String> {
let name = CString::new(name).unwrap();
let index = vcall!(self.request, addEntryPoint(self.index, name.as_ptr(), stage));
EntryPointIndex(index)
let mut diagnostics = null_mut();
let module = vcall!(self, loadModule(name.as_ptr(), &mut diagnostics));
if module.is_null() {
let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap()));
Err(std::str::from_utf8(blob.as_slice()).unwrap().to_string())
} else {
Ok(Module(IUnknown(std::ptr::NonNull::new(module as *mut _).unwrap())))
}
}
pub fn create_composite_component_type(&self, components: &[&ComponentType]) -> ComponentType {
let components: Vec<*mut std::ffi::c_void> = unsafe { components.iter().map(|c| c.as_raw()).collect() };
let mut composite_component_type = null_mut();
let mut diagnostics = null_mut();
let res = vcall!(self, createCompositeComponentType(components.as_ptr() as _, components.len() as _, &mut composite_component_type, &mut diagnostics));
ComponentType(IUnknown(std::ptr::NonNull::new(composite_component_type as *mut _).unwrap()))
}
}
pub struct CompiledRequest {
request: CompileRequest,
#[repr(transparent)]
#[derive(Clone)]
pub struct ComponentType(IUnknown);
unsafe impl Interface for ComponentType {
type Vtable = sys::IComponentTypeVtable;
const IID: UUID = uuid(0x5bc42be8, 0x5c50, 0x4929, [0x9e, 0x5e, 0xd1, 0x5e, 0x7c, 0x24, 0x01, 0x5f]);
}
impl CompiledRequest {
pub fn get_entry_point_code(&self, index: EntryPointIndex) -> &[u8] {
let mut out_size = 0;
let ptr = vcall!(self.request, getEntryPointCode(index.0, &mut out_size));
unsafe { slice::from_raw_parts(ptr as *const u8, out_size) }
impl ComponentType {
pub fn link(&self) -> ComponentType {
let mut linked_component_type = null_mut();
let mut diagnostics = null_mut();
let res = vcall!(self, link(&mut linked_component_type, &mut diagnostics));
if linked_component_type.is_null() {
let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap()));
let error = std::str::from_utf8(blob.as_slice()).unwrap().to_string();
println!("Error: {}", error);
}
ComponentType(IUnknown(std::ptr::NonNull::new(linked_component_type as *mut _).unwrap()))
}
pub fn get_entry_point_code(&self, index: i64, target: i64) -> Vec<u8> {
let mut code = null_mut();
let mut diagnostics = null_mut();
let res = vcall!(self, getEntryPointCode(index, target, &mut code, &mut diagnostics));
if code.is_null() {
let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap()));
let error = std::str::from_utf8(blob.as_slice()).unwrap().to_string();
println!("Error: {}", error);
}
let blob = Blob(IUnknown(std::ptr::NonNull::new(code as *mut _).unwrap()));
Vec::from(blob.as_slice())
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct EntryPointIndex(pub i32);
#[repr(transparent)]
#[derive(Clone)]
pub struct EntryPoint(IUnknown);
pub struct CompilationErrors {
errors: String,
unsafe impl Interface for EntryPoint {
type Vtable = sys::IEntryPointVtable;
const IID: UUID = uuid(0x8f241361, 0xf5bd, 0x4ca0, [0xa3, 0xac, 0x02, 0xf7, 0xfa, 0x24, 0x02, 0xb8]);
}
impl std::fmt::Debug for CompilationErrors {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("\n")?;
f.write_str(&self.errors)
unsafe impl Downcast<ComponentType> for EntryPoint {
fn downcast(&self) -> &ComponentType {
unsafe { std::mem::transmute(self) }
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct TypeConformance(IUnknown);
unsafe impl Interface for TypeConformance {
type Vtable = sys::ITypeConformanceVtable;
const IID: UUID = uuid(0x73eb3147, 0xe544, 0x41b5, [0xb8, 0xf0, 0xa2, 0x44, 0xdf, 0x21, 0x94, 0x0b]);
}
unsafe impl Downcast<ComponentType> for TypeConformance {
fn downcast(&self) -> &ComponentType {
unsafe { std::mem::transmute(self) }
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct Module(IUnknown);
unsafe impl Interface for Module {
type Vtable = sys::IModuleVtable;
const IID: UUID = uuid(0x0c720e64, 0x8722, 0x4d31, [0x89, 0x90, 0x63, 0x8a, 0x98, 0xb1, 0xc2, 0x79]);
}
unsafe impl Downcast<ComponentType> for Module {
fn downcast(&self) -> &ComponentType {
unsafe { std::mem::transmute(self) }
}
}
impl Module {
pub fn find_entry_point_by_name(&self, name: &str) -> Option<EntryPoint> {
let name = CString::new(name).unwrap();
let mut entry_point = null_mut();
vcall!(self, findEntryPointByName(name.as_ptr(), &mut entry_point));
Some(EntryPoint(IUnknown(std::ptr::NonNull::new(entry_point as *mut _)?)))
}
pub fn name(&self) -> &str {
let name = vcall!(self, getName());
unsafe { CStr::from_ptr(name).to_str().unwrap() }
}
pub fn file_path(&self) -> &str {
let path = vcall!(self, getFilePath());
unsafe { CStr::from_ptr(path).to_str().unwrap() }
}
pub fn unique_identity(&self) -> &str {
let identity = vcall!(self, getUniqueIdentity());
unsafe { CStr::from_ptr(identity).to_str().unwrap() }
}
}
pub struct TargetDescBuilder {
inner: TargetDesc,
}
impl TargetDescBuilder {
pub fn new() -> TargetDescBuilder {
Self {
inner: TargetDesc {
structureSize: std::mem::size_of::<TargetDesc>(),
..unsafe { std::mem::zeroed() }
}
}
}
pub fn format(mut self, format: CompileTarget) -> Self {
self.inner.format = format;
self
}
pub fn profile(mut self, profile: ProfileID) -> Self {
self.inner.profile = profile;
self
}
pub fn options(mut self, options: &OptionsBuilder) -> Self {
self.inner.compilerOptionEntries = options.options.as_ptr() as _;
self.inner.compilerOptionEntryCount = options.options.len() as _;
self
}
}
impl std::ops::Deref for TargetDescBuilder {
type Target = TargetDesc;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
pub struct SessionDescBuilder {
inner: SessionDesc,
}
impl SessionDescBuilder {
pub fn new() -> SessionDescBuilder {
Self {
inner: SessionDesc {
structureSize: std::mem::size_of::<SessionDesc>(),
..unsafe { std::mem::zeroed() }
}
}
}
pub fn targets(mut self, targets: &[TargetDesc]) -> Self {
self.inner.targets = targets.as_ptr();
self.inner.targetCount = targets.len() as _;
self
}
pub fn search_paths(mut self, paths: &[*const i8]) -> Self {
self.inner.searchPaths = paths.as_ptr();
self.inner.searchPathCount = paths.len() as _;
self
}
pub fn options(mut self, options: &OptionsBuilder) -> Self {
self.inner.compilerOptionEntries = options.options.as_ptr() as _;
self.inner.compilerOptionEntryCount = options.options.len() as _;
self
}
}
impl std::ops::Deref for SessionDescBuilder {
type Target = SessionDesc;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
macro_rules! option {
($name:ident, $func:ident($p_name:ident: $p_type:ident)) => {
#[inline(always)]
pub fn $func(self, $p_name: $p_type) -> Self {
self.push_ints(CompilerOptionName::$name, $p_name as _, 0)
}
};
($name:ident, $func:ident($p_name:ident: &str)) => {
#[inline(always)]
pub fn $func(self, $p_name: &str) -> Self {
self.push_str1(CompilerOptionName::$name, $p_name)
}
};
($name:ident, $func:ident($p_name1:ident: &str, $p_name2:ident: &str)) => {
#[inline(always)]
pub fn $func(self, $p_name1: &str, $p_name2: &str) -> Self {
self.push_str2(CompilerOptionName::$name, $p_name1, $p_name2)
}
};
}
pub struct OptionsBuilder {
strings: Vec<CString>,
options: Vec<sys::slang_CompilerOptionEntry>,
}
impl OptionsBuilder {
pub fn new() -> OptionsBuilder {
OptionsBuilder {
strings: Vec::new(),
options: Vec::new(),
}
}
fn push_ints(mut self, name: CompilerOptionName, i0: i32, i1: i32) -> Self {
self.options.push(sys::slang_CompilerOptionEntry {
name,
value: sys::slang_CompilerOptionValue {
kind: sys::slang_CompilerOptionValueKind::Int,
intValue0: i0,
intValue1: i1,
stringValue0: null(),
stringValue1: null(),
},
});
self
}
fn push_strings(mut self, name: CompilerOptionName, s0: *const i8, s1: *const i8) -> Self {
self.options.push(sys::slang_CompilerOptionEntry {
name,
value: sys::slang_CompilerOptionValue {
kind: sys::slang_CompilerOptionValueKind::String,
intValue0: 0,
intValue1: 0,
stringValue0: s0,
stringValue1: s1,
},
});
self
}
fn push_str1(mut self, name: CompilerOptionName, s0: &str) -> Self {
let s0 = CString::new(s0).unwrap();
let s0_ptr = s0.as_ptr();
self.strings.push(s0);
self.push_strings(name, s0_ptr, null())
}
fn push_str2(mut self, name: CompilerOptionName, s0: &str, s1: &str) -> Self {
let s0 = CString::new(s0).unwrap();
let s0_ptr = s0.as_ptr();
self.strings.push(s0);
let s1 = CString::new(s1).unwrap();
let s1_ptr = s1.as_ptr();
self.strings.push(s1);
self.push_strings(name, s0_ptr, s1_ptr)
}
}
impl OptionsBuilder {
option!(MacroDefine, macro_define(key: &str, value: &str));
option!(Include, include(path: &str));
option!(Language, language(language: SourceLanguage));
option!(MatrixLayoutColumn, matrix_layout_column(enable: bool));
option!(MatrixLayoutRow, matrix_layout_row(enable: bool));
option!(Profile, profile(profile: ProfileID));
option!(Stage, stage(stage: Stage));
option!(Target, target(target: CompileTarget));
option!(WarningsAsErrors, warnings_as_errors(warning_codes: &str));
option!(DisableWarnings, disable_warnings(warning_codes: &str));
option!(EnableWarning, enable_warning(warning_code: &str));
option!(DisableWarning, disable_warning(warning_code: &str));
option!(ReportDownstreamTime, report_downstream_time(enable: bool));
option!(ReportPerfBenchmark, report_perf_benchmark(enable: bool));
option!(SkipSPIRVValidation, skip_spirv_validation(enable: bool));
// Target
option!(Capability, capability(capability: CapabilityID));
option!(DefaultImageFormatUnknown, default_image_format_unknown(enable: bool));
option!(DisableDynamicDispatch, disable_dynamic_dispatch(enable: bool));
option!(DisableSpecialization, disable_specialization(enable: bool));
option!(FloatingPointMode, floating_point_mode(mode: FloatingPointMode));
option!(DebugInformation, debug_information(level: DebugInfoLevel));
option!(LineDirectiveMode, line_directive_mode(mode: LineDirectiveMode));
option!(Optimization, optimization(level: OptimizationLevel));
option!(Obfuscate, obfuscate(enable: bool));
option!(GLSLForceScalarLayout, glsl_force_scalar_layout(enable: bool));
option!(EmitSpirvDirectly, emit_spirv_directly(enable: bool));
// Debugging
option!(NoCodeGen, no_code_gen(enable: bool));
// Experimental
option!(NoMangle, no_mangle(enable: bool));
option!(ValidateUniformity, validate_uniformity(enable: bool));
}
#[cfg(test)]
mod tests {
#[test]