From 0efe5f8115ec9d19d9e73322ee8a2c91654682d4 Mon Sep 17 00:00:00 2001 From: Lauro Oyen Date: Wed, 18 Sep 2024 18:26:08 +0200 Subject: [PATCH] Improve error handling --- README.md | 8 +-- slang-sys/src/lib.rs | 138 +++++++++++++++++++++---------------------- src/lib.rs | 114 +++++++++++++++++++++-------------- 3 files changed, 144 insertions(+), 116 deletions(-) diff --git a/README.md b/README.md index 2c64838..358f692 100644 --- a/README.md +++ b/README.md @@ -32,12 +32,12 @@ let module = session.load_module("filename.slang").unwrap(); let entry_point = module.find_entry_point_by_name("main").unwrap(); let program = session.create_composite_component_type(&[ - module.downcast(), entry_point.downcast(), -]); + module.downcast().clone(), entry_point.downcast().clone(), +]).unwrap(); -let linked_program = program.link(); +let linked_program = program.link().unwrap(); -let shader_bytecode = linked_program.get_entry_point_code(0, 0); +let shader_bytecode = linked_program.get_entry_point_code(0, 0).unwrap(); ``` ## Installation diff --git a/slang-sys/src/lib.rs b/slang-sys/src/lib.rs index cc30955..dbe0909 100644 --- a/slang-sys/src/lib.rs +++ b/slang-sys/src/lib.rs @@ -10,84 +10,84 @@ use std::ffi::{c_char, c_int, c_void}; pub struct IBlobVtable { pub _base: ISlangUnknown__bindgen_vtable, - pub getBufferPointer: unsafe extern "stdcall" fn(*mut c_void) -> *const c_void, - pub getBufferSize: unsafe extern "stdcall" fn(*mut c_void) -> usize, + pub getBufferPointer: unsafe extern "C" fn(*mut c_void) -> *const c_void, + pub getBufferSize: unsafe extern "C" fn(*mut c_void) -> usize, } #[repr(C)] pub struct IGlobalSessionVtable { pub _base: ISlangUnknown__bindgen_vtable, - pub createSession: unsafe extern "stdcall" fn(*mut c_void, desc: *const slang_SessionDesc, outSession: *mut *mut slang_ISession) -> SlangResult, - pub findProfile: unsafe extern "stdcall" fn(*mut c_void, name: *const c_char) -> SlangProfileID, - pub setDownstreamCompilerPath: unsafe extern "stdcall" fn(*mut c_void, passThrough: SlangPassThrough, path: *const c_char), + pub createSession: unsafe extern "C" fn(*mut c_void, desc: *const slang_SessionDesc, outSession: *mut *mut slang_ISession) -> SlangResult, + pub findProfile: unsafe extern "C" fn(*mut c_void, name: *const c_char) -> SlangProfileID, + pub setDownstreamCompilerPath: unsafe extern "C" fn(*mut c_void, passThrough: SlangPassThrough, path: *const c_char), #[deprecated( note = "Use setLanguagePrelude instead")] - pub setDownstreamCompilerPrelude: unsafe extern "stdcall" fn(*mut c_void, passThrough: SlangPassThrough, preludeText: *const c_char), + pub setDownstreamCompilerPrelude: unsafe extern "C" fn(*mut c_void, passThrough: SlangPassThrough, preludeText: *const c_char), #[deprecated( note = "Use getLanguagePrelude instead")] - pub getDownstreamCompilerPrelude: unsafe extern "stdcall" fn(*mut c_void, passThrough: SlangPassThrough, outPrelude: *mut *mut ISlangBlob), - pub getBuildTagString: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char, - pub setDefaultDownstreamCompiler: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, defaultCompiler: SlangPassThrough) -> SlangResult, - pub getDefaultDownstreamCompiler: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage) -> SlangPassThrough, - pub setLanguagePrelude: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, preludeText: *const c_char), - pub getLanguagePrelude: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, outPrelude: *mut *mut ISlangBlob), - pub createCompileRequest: unsafe extern "stdcall" fn(*mut c_void, *mut *mut slang_ICompileRequest) -> SlangResult, - pub addBuiltins: unsafe extern "stdcall" fn(*mut c_void, sourcePath: *const c_char, sourceString: *const c_char), - pub setSharedLibraryLoader: unsafe extern "stdcall" fn(*mut c_void, loader: *mut ISlangSharedLibraryLoader), - pub getSharedLibraryLoader: unsafe extern "stdcall" fn(*mut c_void) -> *mut ISlangSharedLibraryLoader, - pub checkCompileTargetSupport: unsafe extern "stdcall" fn(*mut c_void, target: SlangCompileTarget) -> SlangResult, - pub checkPassThroughSupport: unsafe extern "stdcall" fn(*mut c_void, passThrough: SlangPassThrough) -> SlangResult, - pub compileStdLib: unsafe extern "stdcall" fn(*mut c_void, flags: slang_CompileStdLibFlags) -> SlangResult, - pub loadStdLib: unsafe extern "stdcall" fn(*mut c_void, stdLib: *const c_void, stdLibSizeInBytes: usize) -> SlangResult, - pub saveStdLib: unsafe extern "stdcall" fn(*mut c_void, archiveType: SlangArchiveType, outBlob: *mut *mut ISlangBlob) -> SlangResult, - pub findCapability: unsafe extern "stdcall" fn(*mut c_void, name: *const c_char) -> SlangCapabilityID, - pub setDownstreamCompilerForTransition: unsafe extern "stdcall" fn(*mut c_void, source: SlangCompileTarget, target: SlangCompileTarget, compiler: SlangPassThrough), - pub getDownstreamCompilerForTransition: unsafe extern "stdcall" fn(*mut c_void, source: SlangCompileTarget, target: SlangCompileTarget) -> SlangPassThrough, - pub getCompilerElapsedTime: unsafe extern "stdcall" fn(*mut c_void, outTotalTime: *mut f64, outDownstreamTime: *mut f64), - pub setSPIRVCoreGrammar: unsafe extern "stdcall" fn(*mut c_void, jsonPath: *const c_char) -> SlangResult, - pub parseCommandLineArguments: unsafe extern "stdcall" fn(*mut c_void, argc: c_int, argv: *const *const c_char, outSessionDesc: *mut slang_SessionDesc, outAuxAllocation: *mut *mut ISlangUnknown) -> SlangResult, - pub getSessionDescDigest: unsafe extern "stdcall" fn(*mut c_void, sessionDesc: *const slang_SessionDesc, outBlob: *mut *mut ISlangBlob) -> SlangResult, + pub getDownstreamCompilerPrelude: unsafe extern "C" fn(*mut c_void, passThrough: SlangPassThrough, outPrelude: *mut *mut ISlangBlob), + pub getBuildTagString: unsafe extern "C" fn(*mut c_void) -> *const c_char, + pub setDefaultDownstreamCompiler: unsafe extern "C" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, defaultCompiler: SlangPassThrough) -> SlangResult, + pub getDefaultDownstreamCompiler: unsafe extern "C" fn(*mut c_void, sourceLanguage: SlangSourceLanguage) -> SlangPassThrough, + pub setLanguagePrelude: unsafe extern "C" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, preludeText: *const c_char), + pub getLanguagePrelude: unsafe extern "C" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, outPrelude: *mut *mut ISlangBlob), + pub createCompileRequest: unsafe extern "C" fn(*mut c_void, *mut *mut slang_ICompileRequest) -> SlangResult, + pub addBuiltins: unsafe extern "C" fn(*mut c_void, sourcePath: *const c_char, sourceString: *const c_char), + pub setSharedLibraryLoader: unsafe extern "C" fn(*mut c_void, loader: *mut ISlangSharedLibraryLoader), + pub getSharedLibraryLoader: unsafe extern "C" fn(*mut c_void) -> *mut ISlangSharedLibraryLoader, + pub checkCompileTargetSupport: unsafe extern "C" fn(*mut c_void, target: SlangCompileTarget) -> SlangResult, + pub checkPassThroughSupport: unsafe extern "C" fn(*mut c_void, passThrough: SlangPassThrough) -> SlangResult, + pub compileStdLib: unsafe extern "C" fn(*mut c_void, flags: slang_CompileStdLibFlags) -> SlangResult, + pub loadStdLib: unsafe extern "C" fn(*mut c_void, stdLib: *const c_void, stdLibSizeInBytes: usize) -> SlangResult, + pub saveStdLib: unsafe extern "C" fn(*mut c_void, archiveType: SlangArchiveType, outBlob: *mut *mut ISlangBlob) -> SlangResult, + pub findCapability: unsafe extern "C" fn(*mut c_void, name: *const c_char) -> SlangCapabilityID, + pub setDownstreamCompilerForTransition: unsafe extern "C" fn(*mut c_void, source: SlangCompileTarget, target: SlangCompileTarget, compiler: SlangPassThrough), + pub getDownstreamCompilerForTransition: unsafe extern "C" fn(*mut c_void, source: SlangCompileTarget, target: SlangCompileTarget) -> SlangPassThrough, + pub getCompilerElapsedTime: unsafe extern "C" fn(*mut c_void, outTotalTime: *mut f64, outDownstreamTime: *mut f64), + pub setSPIRVCoreGrammar: unsafe extern "C" fn(*mut c_void, jsonPath: *const c_char) -> SlangResult, + pub parseCommandLineArguments: unsafe extern "C" fn(*mut c_void, argc: c_int, argv: *const *const c_char, outSessionDesc: *mut slang_SessionDesc, outAuxAllocation: *mut *mut ISlangUnknown) -> SlangResult, + pub getSessionDescDigest: unsafe extern "C" fn(*mut c_void, sessionDesc: *const slang_SessionDesc, outBlob: *mut *mut ISlangBlob) -> SlangResult, } #[repr(C)] pub struct ISessionVtable { pub _base: ISlangUnknown__bindgen_vtable, - pub getGlobalSession: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_IGlobalSession, - pub loadModule: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, - pub loadModuleFromSource: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, source: *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, - pub createCompositeComponentType: unsafe extern "stdcall" fn(*mut c_void, componentTypes: *const *const slang_IComponentType, componentTypeCount: SlangInt, outCompositeComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, - pub specializeType: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, specializationArgs: *const slang_SpecializationArg, specializationArgCount: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeReflection, - pub getTypeLayout: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, targetIndex: SlangInt, rules: slang_LayoutRules, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeLayoutReflection, - pub getContainerType: unsafe extern "stdcall" fn(*mut c_void, elementType: *mut slang_TypeReflection, containerType: slang_ContainerType, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeReflection, - pub getDynamicType: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_TypeReflection, - pub getTypeRTTIMangledName: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, outNameBlob: *mut *mut ISlangBlob) -> SlangResult, - pub getTypeConformanceWitnessMangledName: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outNameBlob: *mut *mut ISlangBlob) -> SlangResult, - pub getTypeConformanceWitnessSequentialID: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outId: *mut u32) -> SlangResult, - pub createCompileRequest: unsafe extern "stdcall" fn(*mut c_void, outCompileRequest: *mut *mut slang_ICompileRequest) -> SlangResult, - pub createTypeConformanceComponentType: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outConformance: *mut *mut slang_ITypeConformance, conformanceIdOverride: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, - pub loadModuleFromIRBlob: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, source: *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, - pub getLoadedModuleCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt, - pub getLoadedModule: unsafe extern "stdcall" fn(*mut c_void, index: SlangInt) -> *mut slang_IModule, - pub isBinaryModuleUpToDate: unsafe extern "stdcall" fn(*mut c_void, modulePath: *const c_char, binaryModuleBlob: *mut ISlangBlob) -> bool, - pub loadModuleFromSourceString: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, string: *const c_char, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, + pub getGlobalSession: unsafe extern "C" fn(*mut c_void) -> *mut slang_IGlobalSession, + pub loadModule: unsafe extern "C" fn(*mut c_void, moduleName: *const c_char, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, + pub loadModuleFromSource: unsafe extern "C" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, source: *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, + pub createCompositeComponentType: unsafe extern "C" fn(*mut c_void, componentTypes: *const *const slang_IComponentType, componentTypeCount: SlangInt, outCompositeComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub specializeType: unsafe extern "C" fn(*mut c_void, type_: *mut slang_TypeReflection, specializationArgs: *const slang_SpecializationArg, specializationArgCount: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeReflection, + pub getTypeLayout: unsafe extern "C" fn(*mut c_void, type_: *mut slang_TypeReflection, targetIndex: SlangInt, rules: slang_LayoutRules, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeLayoutReflection, + pub getContainerType: unsafe extern "C" fn(*mut c_void, elementType: *mut slang_TypeReflection, containerType: slang_ContainerType, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeReflection, + pub getDynamicType: unsafe extern "C" fn(*mut c_void) -> *mut slang_TypeReflection, + pub getTypeRTTIMangledName: unsafe extern "C" fn(*mut c_void, type_: *mut slang_TypeReflection, outNameBlob: *mut *mut ISlangBlob) -> SlangResult, + pub getTypeConformanceWitnessMangledName: unsafe extern "C" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outNameBlob: *mut *mut ISlangBlob) -> SlangResult, + pub getTypeConformanceWitnessSequentialID: unsafe extern "C" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outId: *mut u32) -> SlangResult, + pub createCompileRequest: unsafe extern "C" fn(*mut c_void, outCompileRequest: *mut *mut slang_ICompileRequest) -> SlangResult, + pub createTypeConformanceComponentType: unsafe extern "C" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outConformance: *mut *mut slang_ITypeConformance, conformanceIdOverride: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub loadModuleFromIRBlob: unsafe extern "C" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, source: *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, + pub getLoadedModuleCount: unsafe extern "C" fn(*mut c_void) -> SlangInt, + pub getLoadedModule: unsafe extern "C" fn(*mut c_void, index: SlangInt) -> *mut slang_IModule, + pub isBinaryModuleUpToDate: unsafe extern "C" fn(*mut c_void, modulePath: *const c_char, binaryModuleBlob: *mut ISlangBlob) -> bool, + pub loadModuleFromSourceString: unsafe extern "C" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, string: *const c_char, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule, } #[repr(C)] pub struct IComponentTypeVtable { pub _base: ISlangUnknown__bindgen_vtable, - pub getSession: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_ISession, - pub getLayout: unsafe extern "stdcall" fn(*mut c_void, targetIndex: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_ProgramLayout, - pub getSpecializationParamCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt, - pub getEntryPointCode: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outCode: *mut *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, - pub getResultAsFileSystem: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outFileSystem: *mut *mut ISlangMutableFileSystem) -> SlangResult, - pub getEntryPointHash: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outHash: *mut *mut ISlangBlob), - pub specialize: unsafe extern "stdcall" fn(*mut c_void, specializationArgs: *const slang_SpecializationArg, specializationArgCount: SlangInt, outSpecializedComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, - pub link: unsafe extern "stdcall" fn(*mut c_void, outLinkedComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, - pub getEntryPointHostCallable: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, outSharedLibrary: *mut *mut ISlangSharedLibrary, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, - pub renameEntryPoint: unsafe extern "stdcall" fn(*mut c_void, newName: *const c_char, outEntryPoint: *mut *mut slang_IComponentType) -> SlangResult, - pub linkWithOptions: unsafe extern "stdcall" fn(*mut c_void, outLinkedComponentType: *mut *mut slang_IComponentType, compilerOptionEntryCount: u32, compilerOptionEntries: *mut slang_CompilerOptionEntry, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, - pub getTargetCode: unsafe extern "stdcall" fn(*mut c_void, targetIndex: SlangInt, outCode: *mut *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub getSession: unsafe extern "C" fn(*mut c_void) -> *mut slang_ISession, + pub getLayout: unsafe extern "C" fn(*mut c_void, targetIndex: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_ProgramLayout, + pub getSpecializationParamCount: unsafe extern "C" fn(*mut c_void) -> SlangInt, + pub getEntryPointCode: unsafe extern "C" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outCode: *mut *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub getResultAsFileSystem: unsafe extern "C" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outFileSystem: *mut *mut ISlangMutableFileSystem) -> SlangResult, + pub getEntryPointHash: unsafe extern "C" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outHash: *mut *mut ISlangBlob), + pub specialize: unsafe extern "C" fn(*mut c_void, specializationArgs: *const slang_SpecializationArg, specializationArgCount: SlangInt, outSpecializedComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub link: unsafe extern "C" fn(*mut c_void, outLinkedComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub getEntryPointHostCallable: unsafe extern "C" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, outSharedLibrary: *mut *mut ISlangSharedLibrary, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub renameEntryPoint: unsafe extern "C" fn(*mut c_void, newName: *const c_char, outEntryPoint: *mut *mut slang_IComponentType) -> SlangResult, + pub linkWithOptions: unsafe extern "C" fn(*mut c_void, outLinkedComponentType: *mut *mut slang_IComponentType, compilerOptionEntryCount: u32, compilerOptionEntries: *mut slang_CompilerOptionEntry, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub getTargetCode: unsafe extern "C" fn(*mut c_void, targetIndex: SlangInt, outCode: *mut *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, } #[repr(C)] @@ -104,15 +104,15 @@ pub struct ITypeConformanceVtable { pub struct IModuleVtable { pub _base: IComponentTypeVtable, - pub findEntryPointByName: unsafe extern "stdcall" fn(*mut c_void, name: *const c_char, outEntryPoint: *mut *mut slang_IEntryPoint) -> SlangResult, - pub getDefinedEntryPointCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt32, - pub getDefinedEntryPoint: unsafe extern "stdcall" fn(*mut c_void, index: SlangInt32, outEntryPoint: *mut *mut slang_IEntryPoint) -> SlangResult, - pub serialize: unsafe extern "stdcall" fn(*mut c_void, outSerializedBlob: *mut *mut ISlangBlob) -> SlangResult, - pub writeToFile: unsafe extern "stdcall" fn(*mut c_void, fileName: *const c_char) -> SlangResult, - pub getName: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char, - pub getFilePath: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char, - pub getUniqueIdentity: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char, - pub findAndCheckEntryPoint: unsafe extern "stdcall" fn(*mut c_void, name: *const c_char, stage: SlangStage, outEntryPoint: *mut *mut slang_IEntryPoint, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, - pub getDependencyFileCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt32, - pub getDependencyFilePath: unsafe extern "stdcall" fn(*mut c_void, index: SlangInt32) -> *const c_char, + pub findEntryPointByName: unsafe extern "C" fn(*mut c_void, name: *const c_char, outEntryPoint: *mut *mut slang_IEntryPoint) -> SlangResult, + pub getDefinedEntryPointCount: unsafe extern "C" fn(*mut c_void) -> SlangInt32, + pub getDefinedEntryPoint: unsafe extern "C" fn(*mut c_void, index: SlangInt32, outEntryPoint: *mut *mut slang_IEntryPoint) -> SlangResult, + pub serialize: unsafe extern "C" fn(*mut c_void, outSerializedBlob: *mut *mut ISlangBlob) -> SlangResult, + pub writeToFile: unsafe extern "C" fn(*mut c_void, fileName: *const c_char) -> SlangResult, + pub getName: unsafe extern "C" fn(*mut c_void) -> *const c_char, + pub getFilePath: unsafe extern "C" fn(*mut c_void) -> *const c_char, + pub getUniqueIdentity: unsafe extern "C" fn(*mut c_void) -> *const c_char, + pub findAndCheckEntryPoint: unsafe extern "C" fn(*mut c_void, name: *const c_char, stage: SlangStage, outEntryPoint: *mut *mut slang_IEntryPoint, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult, + pub getDependencyFileCount: unsafe extern "C" fn(*mut c_void) -> SlangInt32, + pub getDependencyFilePath: unsafe extern "C" fn(*mut c_void, index: SlangInt32) -> *const c_char, } diff --git a/src/lib.rs b/src/lib.rs index 6a09689..52ce85c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,32 @@ const fn uuid(data1: u32, data2: u16, data3: u16, data4: [u8; 8]) -> UUID { } } +pub enum Error { + Code(sys::SlangResult), + Blob(Blob), +} + +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Code(code) => write!(f, "{}", code), + Error::Blob(blob) => write!(f, "{}", blob.as_str().unwrap()), + } + } +} + +pub type Result = std::result::Result; + +fn result_from_blob(code: sys::SlangResult, blob: *mut sys::slang_IBlob) -> Result<()> { + if code < 0 { + Err(Error::Blob(Blob(IUnknown( + std::ptr::NonNull::new(blob as *mut _).unwrap(), + )))) + } else { + Ok(()) + } +} + pub struct ProfileID(sys::SlangProfileID); impl ProfileID { @@ -109,6 +135,10 @@ impl Blob { let size = vcall!(self, getBufferSize()); unsafe { std::slice::from_raw_parts(ptr as *const u8, size) } } + + pub fn as_str(&self) -> std::result::Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.as_slice()) + } } #[repr(transparent)] @@ -149,7 +179,7 @@ impl GlobalSession { pub fn create_session(&self, desc: &SessionDesc) -> Option { let mut session = null_mut(); - let res = vcall!(self, createSession(desc, &mut session)); + vcall!(self, createSession(desc, &mut session)); Some(Session(IUnknown(std::ptr::NonNull::new( session as *mut _, )?))) @@ -181,7 +211,7 @@ unsafe impl Interface for Session { } impl Session { - pub fn load_module(&self, name: &str) -> Result { + pub fn load_module(&self, name: &str) -> Result { let name = CString::new(name).unwrap(); let mut diagnostics = null_mut(); @@ -191,7 +221,7 @@ impl Session { let blob = Blob(IUnknown( std::ptr::NonNull::new(diagnostics as *mut _).unwrap(), )); - Err(std::str::from_utf8(blob.as_slice()).unwrap().to_string()) + Err(Error::Blob(blob)) } else { let module = Module(IUnknown(std::ptr::NonNull::new(module as *mut _).unwrap())); unsafe { (module.as_unknown().vtable().ISlangUnknown_addRef)(module.as_raw()) }; @@ -199,24 +229,29 @@ impl Session { } } - pub fn create_composite_component_type(&self, components: &[&ComponentType]) -> ComponentType { - let components: Vec<*mut std::ffi::c_void> = - unsafe { components.iter().map(|c| c.as_raw()).collect() }; - + pub fn create_composite_component_type( + &self, + components: &[ComponentType], + ) -> Result { let mut composite_component_type = null_mut(); let mut diagnostics = null_mut(); - let res = vcall!( - self, - createCompositeComponentType( - components.as_ptr() as _, - components.len() as _, - &mut composite_component_type, - &mut diagnostics - ) - ); - ComponentType(IUnknown( + + result_from_blob( + vcall!( + self, + createCompositeComponentType( + components.as_ptr() as _, + components.len() as _, + &mut composite_component_type, + &mut diagnostics + ) + ), + diagnostics, + )?; + + Ok(ComponentType(IUnknown( std::ptr::NonNull::new(composite_component_type as *mut _).unwrap(), - )) + ))) } } @@ -235,42 +270,35 @@ unsafe impl Interface for ComponentType { } impl ComponentType { - pub fn link(&self) -> ComponentType { + pub fn link(&self) -> Result { let mut linked_component_type = null_mut(); let mut diagnostics = null_mut(); - let res = vcall!(self, link(&mut linked_component_type, &mut diagnostics)); - if linked_component_type.is_null() { - let blob = Blob(IUnknown( - std::ptr::NonNull::new(diagnostics as *mut _).unwrap(), - )); - let error = std::str::from_utf8(blob.as_slice()).unwrap().to_string(); - println!("Error: {}", error); - } + result_from_blob( + vcall!(self, link(&mut linked_component_type, &mut diagnostics)), + diagnostics, + )?; - ComponentType(IUnknown( + Ok(ComponentType(IUnknown( std::ptr::NonNull::new(linked_component_type as *mut _).unwrap(), - )) + ))) } - pub fn get_entry_point_code(&self, index: i64, target: i64) -> Vec { + pub fn get_entry_point_code(&self, index: i64, target: i64) -> Result { let mut code = null_mut(); let mut diagnostics = null_mut(); - let res = vcall!( - self, - getEntryPointCode(index, target, &mut code, &mut diagnostics) - ); - if code.is_null() { - let blob = Blob(IUnknown( - std::ptr::NonNull::new(diagnostics as *mut _).unwrap(), - )); - let error = std::str::from_utf8(blob.as_slice()).unwrap().to_string(); - println!("Error: {}", error); - } + result_from_blob( + vcall!( + self, + getEntryPointCode(index, target, &mut code, &mut diagnostics) + ), + diagnostics, + )?; - let blob = Blob(IUnknown(std::ptr::NonNull::new(code as *mut _).unwrap())); - Vec::from(blob.as_slice()) + Ok(Blob(IUnknown( + std::ptr::NonNull::new(code as *mut _).unwrap(), + ))) } }