diff --git a/README.md b/README.md index 39f5549..2c64838 100644 --- a/README.md +++ b/README.md @@ -7,23 +7,37 @@ Currently mostly reflects the needs of our own [engine](https://github.com/Float ## Example ```rust -let session = slang::GlobalSession::new(); +let global_session = slang::GlobalSession::new().unwrap(); -let mut compile_request = session.create_compile_request(); +let search_path = std::ffi::CString::new("shaders/directory").unwrap(); -compile_request - .set_codegen_target(slang::CompileTarget::Dxil) - .set_target_profile(session.find_profile("sm_6_5")); +// All compiler options are available through this builder. +let session_options = slang::OptionsBuilder::new() + .optimization(slang::OptimizationLevel::High) + .matrix_layout_row(true); -let entry_point = compile_request - .add_translation_unit(slang::SourceLanguage::Slang, None) - .add_source_file(filepath) - .add_entry_point("main", slang::Stage::Compute); +let target_desc = slang::TargetDescBuilder::new() + .format(slang::CompileTarget::Dxil) + .profile(self.global_session.find_profile("sm_6_5")); -let shader_bytecode = compile_request - .compile() - .expect("Shader compilation failed.") - .get_entry_point_code(entry_point); +let session_desc = slang::SessionDescBuilder::new() + .targets(&[*target_desc]) + .search_paths(&[include_path.as_ptr()]) + .options(&session_options); + +let session = self.global_session.create_session(&session_desc).unwrap(); + +let module = session.load_module("filename.slang").unwrap(); + +let entry_point = module.find_entry_point_by_name("main").unwrap(); + +let program = session.create_composite_component_type(&[ + module.downcast(), entry_point.downcast(), +]); + +let linked_program = program.link(); + +let shader_bytecode = linked_program.get_entry_point_code(0, 0); ``` ## Installation diff --git a/slang-sys/build.rs b/slang-sys/build.rs index c1ba9cb..5cdceba 100644 --- a/slang-sys/build.rs +++ b/slang-sys/build.rs @@ -98,19 +98,23 @@ impl bindgen::callbacks::ParseCallbacks for ParseCallback { } /// Converts `snake_case` or `SNAKE_CASE` to `PascalCase`. +/// If the input is already in `PascalCase` it will be returned as is. fn pascal_case_from_snake_case(snake_case: &str) -> String { let mut result = String::new(); - let mut capitalize_next = true; - for c in snake_case.chars() { - if c == '_' { - capitalize_next = true; - } else { - if capitalize_next { + let should_lower = snake_case + .chars() + .filter(|c| c.is_alphabetic()) + .all(|c| c.is_uppercase()); + + for part in snake_case.split('_') { + for (i, c) in part.chars().enumerate() { + if i == 0 { result.push(c.to_ascii_uppercase()); - capitalize_next = false; - } else { + } else if should_lower { result.push(c.to_ascii_lowercase()); + } else { + result.push(c); } } } diff --git a/slang-sys/src/lib.rs b/slang-sys/src/lib.rs index 6d4ea5d..72bb2a4 100644 --- a/slang-sys/src/lib.rs +++ b/slang-sys/src/lib.rs @@ -3,10 +3,15 @@ include!(concat!(env!("OUT_DIR"), "/bindings.rs")); use std::ffi::{c_char, c_int, c_void}; -// The vtables below are manually implemented since bindgen does not yet support -// generating vtables for types with base classes, a critical part of COM interfaces. +// Based on Slang version 2024.1.6 -// Based on Slang version 2024.0.10 +#[repr(C)] +pub struct IBlobVtable { + pub _base: ISlangUnknown__bindgen_vtable, + + pub getBufferPointer: unsafe extern "stdcall" fn(*mut c_void) -> *const c_void, + pub getBufferSize: unsafe extern "stdcall" fn(*mut c_void) -> usize, +} #[repr(C)] pub struct IGlobalSessionVtable { @@ -24,7 +29,7 @@ pub struct IGlobalSessionVtable { pub getDefaultDownstreamCompiler: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage) -> SlangPassThrough, pub setLanguagePrelude: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, preludeText: *const c_char), pub getLanguagePrelude: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, outPrelude: *mut *mut ISlangBlob), - pub createCompileRequest: unsafe extern "stdcall" fn(*mut c_void, *mut *mut c_void) -> SlangResult, + pub createCompileRequest: unsafe extern "stdcall" fn(*mut c_void, *mut *mut slang_ICompileRequest) -> SlangResult, pub addBuiltins: unsafe extern "stdcall" fn(*mut c_void, sourcePath: *const c_char, sourceString: *const c_char), pub setSharedLibraryLoader: unsafe extern "stdcall" fn(*mut c_void, loader: *mut ISlangSharedLibraryLoader), pub getSharedLibraryLoader: unsafe extern "stdcall" fn(*mut c_void) -> *mut ISlangSharedLibraryLoader, @@ -39,85 +44,70 @@ pub struct IGlobalSessionVtable { pub getCompilerElapsedTime: unsafe extern "stdcall" fn(*mut c_void, outTotalTime: *mut f64, outDownstreamTime: *mut f64), pub setSPIRVCoreGrammar: unsafe extern "stdcall" fn(*mut c_void, jsonPath: *const c_char) -> SlangResult, pub parseCommandLineArguments: unsafe extern "stdcall" fn(*mut c_void, argc: c_int, argv: *const *const c_char, outSessionDesc: *mut slang_SessionDesc, outAuxAllocation: *mut *mut ISlangUnknown) -> SlangResult, + pub getSessionDescDigest: unsafe extern "stdcall" fn(*mut c_void, sessionDesc: *const slang_SessionDesc, outBlob: *mut *mut ISlangBlob) -> SlangResult, } #[repr(C)] -pub struct ICompileRequestVtable { +pub struct ISessionVtable { pub _base: ISlangUnknown__bindgen_vtable, - pub setFileSystem: unsafe extern "stdcall" fn(*mut c_void, fileSystem: *mut ISlangFileSystem), - pub setCompileFlags: unsafe extern "stdcall" fn(*mut c_void, flags: SlangCompileFlags), - pub getCompileFlags: unsafe extern "stdcall" fn(*mut c_void) -> SlangCompileFlags, - pub setDumpIntermediates: unsafe extern "stdcall" fn(*mut c_void, enable: c_int), - pub setDumpIntermediatePrefix: unsafe extern "stdcall" fn(*mut c_void, prefix: *const c_char), - pub setLineDirectiveMode: unsafe extern "stdcall" fn(*mut c_void, mode: SlangLineDirectiveMode), - pub setCodeGenTarget: unsafe extern "stdcall" fn(*mut c_void, target: SlangCompileTarget), - pub addCodeGenTarget: unsafe extern "stdcall" fn(*mut c_void, target: SlangCompileTarget) -> c_int, - pub setTargetProfile: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, profile: SlangProfileID), - pub setTargetFlags: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, flags: SlangTargetFlags), - pub setTargetFloatingPointMode: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, mode: SlangFloatingPointMode), - #[deprecated( note = "Use setMatrixLayoutMode instead")] - pub setTargetMatrixLayoutMode: unsafe extern "stdcall" fn(*mut c_void, target: c_int, mode: SlangMatrixLayoutMode), - pub setMatrixLayoutMode: unsafe extern "stdcall" fn(*mut c_void, mode: SlangMatrixLayoutMode), - pub setDebugInfoLevel: unsafe extern "stdcall" fn(*mut c_void, level: SlangDebugInfoLevel), - pub setOptimizationLevel: unsafe extern "stdcall" fn(*mut c_void, level: SlangOptimizationLevel), - pub setOutputContainerFormat: unsafe extern "stdcall" fn(*mut c_void, format: SlangContainerFormat), - pub setPassThrough: unsafe extern "stdcall" fn(*mut c_void, passThrough: SlangPassThrough), - pub setDiagnosticCallback: unsafe extern "stdcall" fn(*mut c_void, callback: SlangDiagnosticCallback, userData: *const c_void), - pub setWriter: unsafe extern "stdcall" fn(*mut c_void, channel: SlangWriterChannel, writer: *mut ISlangWriter), - pub getWriter: unsafe extern "stdcall" fn(*mut c_void, channel: SlangWriterChannel) -> *mut ISlangWriter, - pub addSearchPath: unsafe extern "stdcall" fn(*mut c_void, searchDir: *const c_char), - pub addPreprocessorDefine: unsafe extern "stdcall" fn(*mut c_void, key: *const c_char, value: *const c_char), - pub processCommandLineArguments: unsafe extern "stdcall" fn(*mut c_void, args: *const *const c_char, argCount: c_int) -> SlangResult, - pub addTranslationUnit: unsafe extern "stdcall" fn(*mut c_void, language: SlangSourceLanguage, name: *const c_char) -> c_int, - pub setDefaultModuleName: unsafe extern "stdcall" fn(*mut c_void, defaultModuleName: *const c_char), - pub addTranslationUnitPreprocessorDefine: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, key: *const c_char, value: *const c_char), - pub addTranslationUnitSourceFile: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, path: *const c_char), - pub addTranslationUnitSourceString: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, path: *const c_char, source: *const c_char), - pub addLibraryReference: unsafe extern "stdcall" fn(*mut c_void, basePath: *const c_char, libData: *const c_void, libDataSize: usize) -> SlangResult, - pub addTranslationUnitSourceStringSpan: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, path: *const c_char, sourceBegin: *const c_char, sourceEnd: *const c_char), - pub addTranslationUnitSourceBlob: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, path: *const c_char, sourceBlob: *mut ISlangBlob), - pub addEntryPoint: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, name: *const c_char, stage: SlangStage) -> c_int, - pub addEntryPointEx: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, name: *const c_char, stage: SlangStage, genericArgCount: c_int, genericArgs: *const *const c_char) -> c_int, - pub setGlobalGenericArgs: unsafe extern "stdcall" fn(*mut c_void, genericArgCount: c_int, genericArgs: *const *const c_char) -> SlangResult, - pub setTypeNameForGlobalExistentialTypeParam: unsafe extern "stdcall" fn(*mut c_void, slotIndex: c_int, typeName: *const c_char) -> SlangResult, - pub setTypeNameForEntryPointExistentialTypeParam: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, slotIndex: c_int, typeName: *const c_char) -> SlangResult, - pub setAllowGLSLInput: unsafe extern "stdcall" fn(*mut c_void, value: bool), - pub compile: unsafe extern "stdcall" fn(*mut c_void) -> SlangResult, - pub getDiagnosticOutput: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char, - pub getDiagnosticOutputBlob: unsafe extern "stdcall" fn(*mut c_void, outBlob: *mut *mut ISlangBlob) -> SlangResult, - pub getDependencyFileCount: unsafe extern "stdcall" fn(*mut c_void) -> c_int, - pub getDependencyFilePath: unsafe extern "stdcall" fn(*mut c_void, index: c_int) -> *const c_char, - pub getTranslationUnitCount: unsafe extern "stdcall" fn(*mut c_void) -> c_int, - pub getEntryPointSource: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int) -> *const c_char, - pub getEntryPointCode: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, outSize: *mut usize) -> *const c_void, - pub getEntryPointCodeBlob: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, outBlob: *mut *mut ISlangBlob) -> SlangResult, - pub getEntryPointHostCallable: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, outSharedLibrary: *mut *mut ISlangSharedLibrary) -> SlangResult, - pub getTargetCodeBlob: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, outBlob: *mut *mut ISlangBlob) -> SlangResult, - pub getTargetHostCallable: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, outSharedLibrary: *mut *mut ISlangSharedLibrary) -> SlangResult, - pub getCompileRequestCode: unsafe extern "stdcall" fn(*mut c_void, outSize: *mut usize) -> *const c_void, - pub getCompileRequestResultAsFileSystem: unsafe extern "stdcall" fn(*mut c_void) -> *mut ISlangMutableFileSystem, - pub getContainerCode: unsafe extern "stdcall" fn(*mut c_void, outBlob: *mut *mut ISlangBlob) -> SlangResult, - pub loadRepro: unsafe extern "stdcall" fn(*mut c_void, fileSystem: *mut ISlangFileSystem, data: *const c_void, size: usize) -> SlangResult, - pub saveRepro: unsafe extern "stdcall" fn(*mut c_void, outBlob: *mut *mut ISlangBlob) -> SlangResult, - pub enableReproCapture: unsafe extern "stdcall" fn(*mut c_void) -> SlangResult, - pub getProgram: unsafe extern "stdcall" fn(*mut c_void, outProgram: *mut *mut slang_IComponentType) -> SlangResult, - pub getEntryPoint: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, outEntryPoint: *mut *mut slang_IComponentType) -> SlangResult, - pub getModule: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, outModule: *mut *mut slang_IModule) -> SlangResult, - pub getSession: unsafe extern "stdcall" fn(*mut c_void, outSession: *mut *mut slang_ISession) -> SlangResult, - pub getReflection: unsafe extern "stdcall" fn(*mut c_void) -> *mut SlangReflection, - pub setCommandLineCompilerMode: unsafe extern "stdcall" fn(*mut c_void), - pub addTargetCapability: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, capability: SlangCapabilityID) -> SlangResult, - pub getProgramWithEntryPoints: unsafe extern "stdcall" fn(*mut c_void, outProgram: *mut *mut slang_IComponentType) -> SlangResult, - pub isParameterLocationUsed: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, category: SlangParameterCategory, spaceIndex: SlangUInt, registerIndex: SlangUInt, outUsed: *mut bool) -> SlangResult, - pub setTargetLineDirectiveMode: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, mode: SlangLineDirectiveMode), - pub setTargetForceGLSLScalarBufferLayout: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, forceScalarLayout: bool), - pub overrideDiagnosticSeverity: unsafe extern "stdcall" fn(*mut c_void, messageID: SlangInt, overrideSeverity: SlangSeverity), - pub getDiagnosticFlags: unsafe extern "stdcall" fn(*mut c_void) -> SlangDiagnosticFlags, - pub setDiagnosticFlags: unsafe extern "stdcall" fn(*mut c_void, flags: SlangDiagnosticFlags), - pub setDebugInfoFormat: unsafe extern "stdcall" fn(*mut c_void, debugFormat: SlangDebugInfoFormat), - pub setEnableEffectAnnotations: unsafe extern "stdcall" fn(*mut c_void, value: bool), - pub setReportDownstreamTime: unsafe extern "stdcall" fn(*mut c_void, value: bool), - pub setReportPerfBenchmark: unsafe extern "stdcall" fn(*mut c_void, value: bool), - pub setSkipSPIRVValidation: unsafe extern "stdcall" fn(*mut c_void, value: bool), + pub getGlobalSession: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_IGlobalSession, + pub loadModule: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, + pub loadModuleFromSource: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, source: *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, + pub createCompositeComponentType: unsafe extern "stdcall" fn(*mut c_void, componentTypes: *const *const slang_IComponentType, componentTypeCount: SlangInt, outCompositeComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub specializeType: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, specializationArgs: *const slang_SpecializationArg, specializationArgCount: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeReflection, + pub getTypeLayout: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, targetIndex: SlangInt, rules: slang_LayoutRules, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeLayoutReflection, + pub getContainerType: unsafe extern "stdcall" fn(*mut c_void, elementType: *mut slang_TypeReflection, containerType: slang_ContainerType, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeReflection, + pub getDynamicType: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_TypeReflection, + pub getTypeRTTIMangledName: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, outNameBlob: *mut *mut ISlangBlob) -> SlangResult, + pub getTypeConformanceWitnessMangledName: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outNameBlob: *mut *mut ISlangBlob) -> SlangResult, + pub getTypeConformanceWitnessSequentialID: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outId: *mut u32) -> SlangResult, + pub createCompileRequest: unsafe extern "stdcall" fn(*mut c_void, outCompileRequest: *mut *mut slang_ICompileRequest) -> SlangResult, + pub createTypeConformanceComponentType: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outConformance: *mut *mut slang_ITypeConformance, conformanceIdOverride: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub loadModuleFromIRBlob: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, source: *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, + pub getLoadedModuleCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt, + pub getLoadedModule: unsafe extern "stdcall" fn(*mut c_void, index: SlangInt) -> *mut slang_IModule, + pub isBinaryModuleUpToDate: unsafe extern "stdcall" fn(*mut c_void, modulePath: *const c_char, binaryModuleBlob: *mut ISlangBlob) -> bool, + pub loadModuleFromSourceString: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, string: *const c_char, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, +} + +#[repr(C)] +pub struct IComponentTypeVtable { + pub _base: ISlangUnknown__bindgen_vtable, + + pub getSession: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_ISession, + pub getLayout: unsafe extern "stdcall" fn(*mut c_void, targetIndex: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_ProgramLayout, + pub getSpecializationParamCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt, + pub getEntryPointCode: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outCode: *mut *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub getResultAsFileSystem: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outFileSystem: *mut *mut ISlangMutableFileSystem) -> SlangResult, + pub getEntryPointHash: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outHash: *mut *mut ISlangBlob), + pub specialize: unsafe extern "stdcall" fn(*mut c_void, specializationArgs: *const slang_SpecializationArg, specializationArgCount: SlangInt, outSpecializedComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub link: unsafe extern "stdcall" fn(*mut c_void, outLinkedComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub getEntryPointHostCallable: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, outSharedLibrary: *mut *mut ISlangSharedLibrary, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub renameEntryPoint: unsafe extern "stdcall" fn(*mut c_void, newName: *const c_char, outEntryPoint: *mut *mut slang_IComponentType) -> SlangResult, + pub linkWithOptions: unsafe extern "stdcall" fn(*mut c_void, outLinkedComponentType: *mut *mut slang_IComponentType, compilerOptionEntryCount: u32, compilerOptionEntries: *mut slang_CompilerOptionEntry, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, +} + +#[repr(C)] +pub struct IEntryPointVtable { + pub _base: IComponentTypeVtable, +} + +#[repr(C)] +pub struct ITypeConformanceVtable { + pub _base: IComponentTypeVtable, +} + +#[repr(C)] +pub struct IModuleVtable { + pub _base: IComponentTypeVtable, + + pub findEntryPointByName: unsafe extern "stdcall" fn(*mut c_void, name: *const c_char, outEntryPoint: *mut *mut slang_IEntryPoint) -> SlangResult, + pub getDefinedEntryPointCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt32, + pub getDefinedEntryPoint: unsafe extern "stdcall" fn(*mut c_void, index: SlangInt32, outEntryPoint: *mut *mut slang_IEntryPoint) -> SlangResult, + pub serialize: unsafe extern "stdcall" fn(*mut c_void, outSerializedBlob: *mut *mut ISlangBlob) -> SlangResult, + pub writeToFile: unsafe extern "stdcall" fn(*mut c_void, fileName: *const c_char) -> SlangResult, + pub getName: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char, + pub getFilePath: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char, + pub getUniqueIdentity: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char, } diff --git a/src/lib.rs b/src/lib.rs index de8c876..e8f1124 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,24 @@ -use slang_sys as sys; use std::ffi::{CStr, CString}; -use std::path::Path; -use std::slice; +use std::ptr::{null, null_mut}; -pub use sys::SlangUUID as UUID; -pub use sys::SlangCompileTarget as CompileTarget; -pub use sys::SlangMatrixLayoutMode as MatrixLayoutMode; -pub use sys::SlangOptimizationLevel as OptimizationLevel; -pub use sys::SlangSourceLanguage as SourceLanguage; -pub use sys::SlangStage as Stage; -pub use sys::SlangProfileID as ProfileID; +use slang_sys as sys; + +pub use sys::{ + slang_CompilerOptionName as CompilerOptionName, + slang_SessionDesc as SessionDesc, + slang_TargetDesc as TargetDesc, + SlangCapabilityID as CapabilityID, + SlangCompileTarget as CompileTarget, + SlangDebugInfoLevel as DebugInfoLevel, + SlangFloatingPointMode as FloatingPointMode, + SlangLineDirectiveMode as LineDirectiveMode, + SlangMatrixLayoutMode as MatrixLayoutMode, + SlangOptimizationLevel as OptimizationLevel, + SlangProfileID as ProfileID, + SlangSourceLanguage as SourceLanguage, + SlangStage as Stage, + SlangUUID as UUID, +}; macro_rules! vcall { ($self:expr, $method:ident($($args:expr),*)) => { @@ -34,6 +43,15 @@ unsafe trait Interface: Sized { unsafe fn as_raw(&self) -> *mut T { std::mem::transmute_copy(self) } + + fn as_unknown(&self) -> &IUnknown { + // SAFETY: It is always safe to treat an `Interface` as an `IUnknown`. + unsafe { std::mem::transmute(self) } + } +} + +pub unsafe trait Downcast { + fn downcast(&self) -> &T; } #[repr(transparent)] @@ -41,7 +59,7 @@ pub struct IUnknown(std::ptr::NonNull); unsafe impl Interface for IUnknown { type Vtable = sys::ISlangUnknown__bindgen_vtable; - const IID: UUID = uuid(0x00000000, 0x0000, 0x0000, [0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46]); + const IID: UUID = uuid(0x00000000, 0x0000, 0x0000, [0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46]); } impl Clone for IUnknown { @@ -57,170 +75,416 @@ impl Drop for IUnknown { } } +#[repr(transparent)] +#[derive(Clone)] +pub struct Blob(IUnknown); + +unsafe impl Interface for Blob { + type Vtable = sys::IBlobVtable; + const IID: UUID = uuid(0x8ba5fb08, 0x5195, 0x40e2, [0xac, 0x58, 0x0d, 0x98, 0x9c, 0x3a, 0x01, 0x02]); +} + +impl Blob { + pub fn as_slice(&self) -> &[u8] { + let ptr = vcall!(self, getBufferPointer()); + let size = vcall!(self, getBufferSize()); + unsafe { std::slice::from_raw_parts(ptr as *const u8, size) } + } +} + #[repr(transparent)] #[derive(Clone)] pub struct GlobalSession(IUnknown); unsafe impl Interface for GlobalSession { type Vtable = sys::IGlobalSessionVtable; - const IID: UUID = uuid(0xc140b5fd, 0xc78, 0x452e, [0xba, 0x7c, 0x1a, 0x1e, 0x70, 0xc7, 0xf7, 0x1c]); + const IID: UUID = uuid(0xc140b5fd, 0x0c78, 0x452e, [0xba, 0x7c, 0x1a, 0x1e, 0x70, 0xc7, 0xf7, 0x1c]); } impl GlobalSession { - pub fn new() -> GlobalSession { - unsafe { - let mut global_session = std::ptr::null_mut(); - sys::slang_createGlobalSession(sys::SLANG_API_VERSION as _, &mut global_session); - GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _).unwrap())) - } + pub fn new() -> Option { + let mut global_session = null_mut(); + unsafe { sys::slang_createGlobalSession(sys::SLANG_API_VERSION as _, &mut global_session) }; + Some(GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _)?))) } - pub fn new_without_std_lib() -> GlobalSession { - unsafe { - let mut global_session = std::ptr::null_mut(); - sys::slang_createGlobalSessionWithoutStdLib(sys::SLANG_API_VERSION as _, &mut global_session); - GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _).unwrap())) - } + pub fn new_without_std_lib() -> Option { + let mut global_session = null_mut(); + unsafe { sys::slang_createGlobalSessionWithoutStdLib(sys::SLANG_API_VERSION as _, &mut global_session) }; + Some(GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _)?))) } - pub fn create_compile_request(&self) -> CompileRequest { - let mut compile_request = std::ptr::null_mut(); - vcall!(self, createCompileRequest(&mut compile_request)); - CompileRequest(IUnknown(std::ptr::NonNull::new(compile_request).unwrap())) + pub fn create_session(&self, desc: &SessionDesc) -> Option { + let mut session = null_mut(); + let res = vcall!(self, createSession(desc, &mut session)); + let session = Session(IUnknown(std::ptr::NonNull::new(session as *mut _)?)); + + // TODO: Without adding an extra reference, the code crashes when Session is dropped. + // Investigate why this is happening, the current solution could cause a memory leak. + unsafe { (session.as_unknown().vtable().ISlangUnknown_addRef)(session.as_raw()) }; + + Some(session) } pub fn find_profile(&self, name: &str) -> ProfileID { let name = CString::new(name).unwrap(); vcall!(self, findProfile(name.as_ptr())) } + + pub fn find_capability(&self, name: &str) -> CapabilityID { + let name = CString::new(name).unwrap(); + vcall!(self, findCapability(name.as_ptr())) + } } #[repr(transparent)] #[derive(Clone)] -pub struct CompileRequest(IUnknown); +pub struct Session(IUnknown); -unsafe impl Interface for CompileRequest { - type Vtable = sys::ICompileRequestVtable; - const IID: UUID = uuid(0x96d33993, 0x317c, 0x4db5, [0xaf, 0xd8, 0x66, 0x6e, 0xe7, 0x72, 0x48, 0xe2]); +unsafe impl Interface for Session { + type Vtable = sys::ISessionVtable; + const IID: UUID = uuid(0x67618701, 0xd116, 0x468f, [0xab, 0x3b, 0x47, 0x4b, 0xed, 0xce, 0x0e, 0x3d]); } -impl CompileRequest { - pub fn set_codegen_target(&mut self, target: CompileTarget) -> &mut Self { - vcall!(self, setCodeGenTarget(target)); - self - } - - pub fn set_matrix_layout_mode(&mut self, layout: MatrixLayoutMode) -> &mut Self { - vcall!(self, setMatrixLayoutMode(layout)); - self - } - - pub fn set_optimization_level(&mut self, level: OptimizationLevel) -> &mut Self { - vcall!(self, setOptimizationLevel(level)); - self - } - - pub fn set_target_profile(&mut self, profile: ProfileID) ->&mut Self { - vcall!(self, setTargetProfile(0, profile)); - self - } - - pub fn add_preprocessor_define(&mut self, key: &str, value: &str) -> &mut Self { - let key = CString::new(key).unwrap(); - let value = CString::new(value).unwrap(); - vcall!(self, addPreprocessorDefine(key.as_ptr(), value.as_ptr())); - self - } - - pub fn add_search_path(&mut self, path: impl AsRef) -> &mut Self { - let path = CString::new(path.as_ref().to_str().unwrap()).unwrap(); - vcall!(self, addSearchPath(path.as_ptr())); - self - } - - pub fn add_translation_unit(&mut self, source_language: SourceLanguage, name: Option<&str>) -> TranslationUnit { - let name = CString::new(name.unwrap_or("")).unwrap(); - let index = vcall!(self, addTranslationUnit(source_language, name.as_ptr())); - - TranslationUnit { - request: self, - index, - } - } - - pub fn compile(self) -> Result { - let r = vcall!(self, compile()); - - if r < 0 { - let out = vcall!(self, getDiagnosticOutput()); - let errors = unsafe { CStr::from_ptr(out).to_str().unwrap().to_string() }; - - Err(CompilationErrors { errors }) - } else { - Ok(CompiledRequest { request: self }) - } - } -} - -pub struct TranslationUnit<'a> { - request: &'a mut CompileRequest, - index: i32, -} - -impl<'a> TranslationUnit<'a> { - pub fn add_preprocessor_define(&mut self, key: &str, value: &str) -> &mut Self { - let key = CString::new(key).unwrap(); - let value = CString::new(value).unwrap(); - vcall!(self.request, addTranslationUnitPreprocessorDefine(self.index, key.as_ptr(), value.as_ptr())); - self - } - - pub fn add_source_file(&mut self, path: impl AsRef) -> &mut Self { - let path = CString::new(path.as_ref().to_str().unwrap()).unwrap(); - vcall!(self.request, addTranslationUnitSourceFile(self.index, path.as_ptr())); - self - } - - pub fn add_source_string(&mut self, path: impl AsRef, source: &str) -> &mut Self { - let path = CString::new(path.as_ref().to_str().unwrap()).unwrap(); - let source = CString::new(source).unwrap(); - vcall!(self.request, addTranslationUnitSourceString(self.index, path.as_ptr(), source.as_ptr())); - self - } - - pub fn add_entry_point(&mut self, name: &str, stage: Stage) -> EntryPointIndex { +impl Session { + pub fn load_module(&self, name: &str) -> Result { let name = CString::new(name).unwrap(); - let index = vcall!(self.request, addEntryPoint(self.index, name.as_ptr(), stage)); - EntryPointIndex(index) + let mut diagnostics = null_mut(); + + let module = vcall!(self, loadModule(name.as_ptr(), &mut diagnostics)); + + if module.is_null() { + let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap())); + Err(std::str::from_utf8(blob.as_slice()).unwrap().to_string()) + } else { + Ok(Module(IUnknown(std::ptr::NonNull::new(module as *mut _).unwrap()))) + } + } + + pub fn create_composite_component_type(&self, components: &[&ComponentType]) -> ComponentType { + let components: Vec<*mut std::ffi::c_void> = unsafe { components.iter().map(|c| c.as_raw()).collect() }; + + let mut composite_component_type = null_mut(); + let mut diagnostics = null_mut(); + let res = vcall!(self, createCompositeComponentType(components.as_ptr() as _, components.len() as _, &mut composite_component_type, &mut diagnostics)); + ComponentType(IUnknown(std::ptr::NonNull::new(composite_component_type as *mut _).unwrap())) } } -pub struct CompiledRequest { - request: CompileRequest, +#[repr(transparent)] +#[derive(Clone)] +pub struct ComponentType(IUnknown); + +unsafe impl Interface for ComponentType { + type Vtable = sys::IComponentTypeVtable; + const IID: UUID = uuid(0x5bc42be8, 0x5c50, 0x4929, [0x9e, 0x5e, 0xd1, 0x5e, 0x7c, 0x24, 0x01, 0x5f]); } -impl CompiledRequest { - pub fn get_entry_point_code(&self, index: EntryPointIndex) -> &[u8] { - let mut out_size = 0; - let ptr = vcall!(self.request, getEntryPointCode(index.0, &mut out_size)); - unsafe { slice::from_raw_parts(ptr as *const u8, out_size) } +impl ComponentType { + pub fn link(&self) -> ComponentType { + let mut linked_component_type = null_mut(); + let mut diagnostics = null_mut(); + let res = vcall!(self, link(&mut linked_component_type, &mut diagnostics)); + + if linked_component_type.is_null() { + let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap())); + let error = std::str::from_utf8(blob.as_slice()).unwrap().to_string(); + println!("Error: {}", error); + } + + ComponentType(IUnknown(std::ptr::NonNull::new(linked_component_type as *mut _).unwrap())) + } + + pub fn get_entry_point_code(&self, index: i64, target: i64) -> Vec { + let mut code = null_mut(); + let mut diagnostics = null_mut(); + let res = vcall!(self, getEntryPointCode(index, target, &mut code, &mut diagnostics)); + + if code.is_null() { + let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap())); + let error = std::str::from_utf8(blob.as_slice()).unwrap().to_string(); + println!("Error: {}", error); + } + + let blob = Blob(IUnknown(std::ptr::NonNull::new(code as *mut _).unwrap())); + Vec::from(blob.as_slice()) } } -#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] -pub struct EntryPointIndex(pub i32); +#[repr(transparent)] +#[derive(Clone)] +pub struct EntryPoint(IUnknown); -pub struct CompilationErrors { - errors: String, +unsafe impl Interface for EntryPoint { + type Vtable = sys::IEntryPointVtable; + const IID: UUID = uuid(0x8f241361, 0xf5bd, 0x4ca0, [0xa3, 0xac, 0x02, 0xf7, 0xfa, 0x24, 0x02, 0xb8]); } -impl std::fmt::Debug for CompilationErrors { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("\n")?; - f.write_str(&self.errors) +unsafe impl Downcast for EntryPoint { + fn downcast(&self) -> &ComponentType { + unsafe { std::mem::transmute(self) } } } +#[repr(transparent)] +#[derive(Clone)] +pub struct TypeConformance(IUnknown); + +unsafe impl Interface for TypeConformance { + type Vtable = sys::ITypeConformanceVtable; + const IID: UUID = uuid(0x73eb3147, 0xe544, 0x41b5, [0xb8, 0xf0, 0xa2, 0x44, 0xdf, 0x21, 0x94, 0x0b]); +} + +unsafe impl Downcast for TypeConformance { + fn downcast(&self) -> &ComponentType { + unsafe { std::mem::transmute(self) } + } +} + +#[repr(transparent)] +#[derive(Clone)] +pub struct Module(IUnknown); + +unsafe impl Interface for Module { + type Vtable = sys::IModuleVtable; + const IID: UUID = uuid(0x0c720e64, 0x8722, 0x4d31, [0x89, 0x90, 0x63, 0x8a, 0x98, 0xb1, 0xc2, 0x79]); +} + +unsafe impl Downcast for Module { + fn downcast(&self) -> &ComponentType { + unsafe { std::mem::transmute(self) } + } +} + +impl Module { + pub fn find_entry_point_by_name(&self, name: &str) -> Option { + let name = CString::new(name).unwrap(); + let mut entry_point = null_mut(); + vcall!(self, findEntryPointByName(name.as_ptr(), &mut entry_point)); + Some(EntryPoint(IUnknown(std::ptr::NonNull::new(entry_point as *mut _)?))) + } + + pub fn name(&self) -> &str { + let name = vcall!(self, getName()); + unsafe { CStr::from_ptr(name).to_str().unwrap() } + } + + pub fn file_path(&self) -> &str { + let path = vcall!(self, getFilePath()); + unsafe { CStr::from_ptr(path).to_str().unwrap() } + } + + pub fn unique_identity(&self) -> &str { + let identity = vcall!(self, getUniqueIdentity()); + unsafe { CStr::from_ptr(identity).to_str().unwrap() } + } +} + +pub struct TargetDescBuilder { + inner: TargetDesc, +} + +impl TargetDescBuilder { + pub fn new() -> TargetDescBuilder { + Self { + inner: TargetDesc { + structureSize: std::mem::size_of::(), + ..unsafe { std::mem::zeroed() } + } + } + } + + pub fn format(mut self, format: CompileTarget) -> Self { + self.inner.format = format; + self + } + + pub fn profile(mut self, profile: ProfileID) -> Self { + self.inner.profile = profile; + self + } + + pub fn options(mut self, options: &OptionsBuilder) -> Self { + self.inner.compilerOptionEntries = options.options.as_ptr() as _; + self.inner.compilerOptionEntryCount = options.options.len() as _; + self + } +} + +impl std::ops::Deref for TargetDescBuilder { + type Target = TargetDesc; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +pub struct SessionDescBuilder { + inner: SessionDesc, +} + +impl SessionDescBuilder { + pub fn new() -> SessionDescBuilder { + Self { + inner: SessionDesc { + structureSize: std::mem::size_of::(), + ..unsafe { std::mem::zeroed() } + } + } + } + + pub fn targets(mut self, targets: &[TargetDesc]) -> Self { + self.inner.targets = targets.as_ptr(); + self.inner.targetCount = targets.len() as _; + self + } + + pub fn search_paths(mut self, paths: &[*const i8]) -> Self { + self.inner.searchPaths = paths.as_ptr(); + self.inner.searchPathCount = paths.len() as _; + self + } + + pub fn options(mut self, options: &OptionsBuilder) -> Self { + self.inner.compilerOptionEntries = options.options.as_ptr() as _; + self.inner.compilerOptionEntryCount = options.options.len() as _; + self + } +} + +impl std::ops::Deref for SessionDescBuilder { + type Target = SessionDesc; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +macro_rules! option { + ($name:ident, $func:ident($p_name:ident: $p_type:ident)) => { + #[inline(always)] + pub fn $func(self, $p_name: $p_type) -> Self { + self.push_ints(CompilerOptionName::$name, $p_name as _, 0) + } + }; + + ($name:ident, $func:ident($p_name:ident: &str)) => { + #[inline(always)] + pub fn $func(self, $p_name: &str) -> Self { + self.push_str1(CompilerOptionName::$name, $p_name) + } + }; + + ($name:ident, $func:ident($p_name1:ident: &str, $p_name2:ident: &str)) => { + #[inline(always)] + pub fn $func(self, $p_name1: &str, $p_name2: &str) -> Self { + self.push_str2(CompilerOptionName::$name, $p_name1, $p_name2) + } + }; +} + +pub struct OptionsBuilder { + strings: Vec, + options: Vec, +} + +impl OptionsBuilder { + pub fn new() -> OptionsBuilder { + OptionsBuilder { + strings: Vec::new(), + options: Vec::new(), + } + } + + fn push_ints(mut self, name: CompilerOptionName, i0: i32, i1: i32) -> Self { + self.options.push(sys::slang_CompilerOptionEntry { + name, + value: sys::slang_CompilerOptionValue { + kind: sys::slang_CompilerOptionValueKind::Int, + intValue0: i0, + intValue1: i1, + stringValue0: null(), + stringValue1: null(), + }, + }); + + self + } + + fn push_strings(mut self, name: CompilerOptionName, s0: *const i8, s1: *const i8) -> Self { + self.options.push(sys::slang_CompilerOptionEntry { + name, + value: sys::slang_CompilerOptionValue { + kind: sys::slang_CompilerOptionValueKind::String, + intValue0: 0, + intValue1: 0, + stringValue0: s0, + stringValue1: s1, + }, + }); + + self + } + + fn push_str1(mut self, name: CompilerOptionName, s0: &str) -> Self { + let s0 = CString::new(s0).unwrap(); + let s0_ptr = s0.as_ptr(); + self.strings.push(s0); + + self.push_strings(name, s0_ptr, null()) + } + + fn push_str2(mut self, name: CompilerOptionName, s0: &str, s1: &str) -> Self { + let s0 = CString::new(s0).unwrap(); + let s0_ptr = s0.as_ptr(); + self.strings.push(s0); + + let s1 = CString::new(s1).unwrap(); + let s1_ptr = s1.as_ptr(); + self.strings.push(s1); + + self.push_strings(name, s0_ptr, s1_ptr) + } +} + +impl OptionsBuilder { + option!(MacroDefine, macro_define(key: &str, value: &str)); + option!(Include, include(path: &str)); + option!(Language, language(language: SourceLanguage)); + option!(MatrixLayoutColumn, matrix_layout_column(enable: bool)); + option!(MatrixLayoutRow, matrix_layout_row(enable: bool)); + option!(Profile, profile(profile: ProfileID)); + option!(Stage, stage(stage: Stage)); + option!(Target, target(target: CompileTarget)); + option!(WarningsAsErrors, warnings_as_errors(warning_codes: &str)); + option!(DisableWarnings, disable_warnings(warning_codes: &str)); + option!(EnableWarning, enable_warning(warning_code: &str)); + option!(DisableWarning, disable_warning(warning_code: &str)); + option!(ReportDownstreamTime, report_downstream_time(enable: bool)); + option!(ReportPerfBenchmark, report_perf_benchmark(enable: bool)); + option!(SkipSPIRVValidation, skip_spirv_validation(enable: bool)); + + // Target + option!(Capability, capability(capability: CapabilityID)); + option!(DefaultImageFormatUnknown, default_image_format_unknown(enable: bool)); + option!(DisableDynamicDispatch, disable_dynamic_dispatch(enable: bool)); + option!(DisableSpecialization, disable_specialization(enable: bool)); + option!(FloatingPointMode, floating_point_mode(mode: FloatingPointMode)); + option!(DebugInformation, debug_information(level: DebugInfoLevel)); + option!(LineDirectiveMode, line_directive_mode(mode: LineDirectiveMode)); + option!(Optimization, optimization(level: OptimizationLevel)); + option!(Obfuscate, obfuscate(enable: bool)); + option!(GLSLForceScalarLayout, glsl_force_scalar_layout(enable: bool)); + option!(EmitSpirvDirectly, emit_spirv_directly(enable: bool)); + + // Debugging + option!(NoCodeGen, no_code_gen(enable: bool)); + + // Experimental + option!(NoMangle, no_mangle(enable: bool)); + option!(ValidateUniformity, validate_uniformity(enable: bool)); +} + #[cfg(test)] mod tests { #[test]