Implement ISession and remove ICompileRequest

The ICompileRequest is part of Slang's legacy API, and will be removed
in the future. The ISession and associated interfaces are to be used
instead.
This commit is contained in:
Lauro Oyen 2024-04-04 13:32:30 +02:00
parent 5abe4533ba
commit 095c719ff8
4 changed files with 507 additions and 235 deletions

View File

@ -7,23 +7,37 @@ Currently mostly reflects the needs of our own [engine](https://github.com/Float
## Example
```rust
let session = slang::GlobalSession::new();
let global_session = slang::GlobalSession::new().unwrap();
let mut compile_request = session.create_compile_request();
let search_path = std::ffi::CString::new("shaders/directory").unwrap();
compile_request
.set_codegen_target(slang::CompileTarget::Dxil)
.set_target_profile(session.find_profile("sm_6_5"));
// All compiler options are available through this builder.
let session_options = slang::OptionsBuilder::new()
.optimization(slang::OptimizationLevel::High)
.matrix_layout_row(true);
let entry_point = compile_request
.add_translation_unit(slang::SourceLanguage::Slang, None)
.add_source_file(filepath)
.add_entry_point("main", slang::Stage::Compute);
let target_desc = slang::TargetDescBuilder::new()
.format(slang::CompileTarget::Dxil)
.profile(self.global_session.find_profile("sm_6_5"));
let shader_bytecode = compile_request
.compile()
.expect("Shader compilation failed.")
.get_entry_point_code(entry_point);
let session_desc = slang::SessionDescBuilder::new()
.targets(&[*target_desc])
.search_paths(&[include_path.as_ptr()])
.options(&session_options);
let session = self.global_session.create_session(&session_desc).unwrap();
let module = session.load_module("filename.slang").unwrap();
let entry_point = module.find_entry_point_by_name("main").unwrap();
let program = session.create_composite_component_type(&[
module.downcast(), entry_point.downcast(),
]);
let linked_program = program.link();
let shader_bytecode = linked_program.get_entry_point_code(0, 0);
```
## Installation

View File

@ -98,19 +98,23 @@ impl bindgen::callbacks::ParseCallbacks for ParseCallback {
}
/// Converts `snake_case` or `SNAKE_CASE` to `PascalCase`.
/// If the input is already in `PascalCase` it will be returned as is.
fn pascal_case_from_snake_case(snake_case: &str) -> String {
let mut result = String::new();
let mut capitalize_next = true;
for c in snake_case.chars() {
if c == '_' {
capitalize_next = true;
} else {
if capitalize_next {
let should_lower = snake_case
.chars()
.filter(|c| c.is_alphabetic())
.all(|c| c.is_uppercase());
for part in snake_case.split('_') {
for (i, c) in part.chars().enumerate() {
if i == 0 {
result.push(c.to_ascii_uppercase());
capitalize_next = false;
} else {
} else if should_lower {
result.push(c.to_ascii_lowercase());
} else {
result.push(c);
}
}
}

View File

@ -3,10 +3,15 @@ include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
use std::ffi::{c_char, c_int, c_void};
// The vtables below are manually implemented since bindgen does not yet support
// generating vtables for types with base classes, a critical part of COM interfaces.
// Based on Slang version 2024.1.6
// Based on Slang version 2024.0.10
#[repr(C)]
pub struct IBlobVtable {
pub _base: ISlangUnknown__bindgen_vtable,
pub getBufferPointer: unsafe extern "stdcall" fn(*mut c_void) -> *const c_void,
pub getBufferSize: unsafe extern "stdcall" fn(*mut c_void) -> usize,
}
#[repr(C)]
pub struct IGlobalSessionVtable {
@ -24,7 +29,7 @@ pub struct IGlobalSessionVtable {
pub getDefaultDownstreamCompiler: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage) -> SlangPassThrough,
pub setLanguagePrelude: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, preludeText: *const c_char),
pub getLanguagePrelude: unsafe extern "stdcall" fn(*mut c_void, sourceLanguage: SlangSourceLanguage, outPrelude: *mut *mut ISlangBlob),
pub createCompileRequest: unsafe extern "stdcall" fn(*mut c_void, *mut *mut c_void) -> SlangResult,
pub createCompileRequest: unsafe extern "stdcall" fn(*mut c_void, *mut *mut slang_ICompileRequest) -> SlangResult,
pub addBuiltins: unsafe extern "stdcall" fn(*mut c_void, sourcePath: *const c_char, sourceString: *const c_char),
pub setSharedLibraryLoader: unsafe extern "stdcall" fn(*mut c_void, loader: *mut ISlangSharedLibraryLoader),
pub getSharedLibraryLoader: unsafe extern "stdcall" fn(*mut c_void) -> *mut ISlangSharedLibraryLoader,
@ -39,85 +44,70 @@ pub struct IGlobalSessionVtable {
pub getCompilerElapsedTime: unsafe extern "stdcall" fn(*mut c_void, outTotalTime: *mut f64, outDownstreamTime: *mut f64),
pub setSPIRVCoreGrammar: unsafe extern "stdcall" fn(*mut c_void, jsonPath: *const c_char) -> SlangResult,
pub parseCommandLineArguments: unsafe extern "stdcall" fn(*mut c_void, argc: c_int, argv: *const *const c_char, outSessionDesc: *mut slang_SessionDesc, outAuxAllocation: *mut *mut ISlangUnknown) -> SlangResult,
pub getSessionDescDigest: unsafe extern "stdcall" fn(*mut c_void, sessionDesc: *const slang_SessionDesc, outBlob: *mut *mut ISlangBlob) -> SlangResult,
}
#[repr(C)]
pub struct ICompileRequestVtable {
pub struct ISessionVtable {
pub _base: ISlangUnknown__bindgen_vtable,
pub setFileSystem: unsafe extern "stdcall" fn(*mut c_void, fileSystem: *mut ISlangFileSystem),
pub setCompileFlags: unsafe extern "stdcall" fn(*mut c_void, flags: SlangCompileFlags),
pub getCompileFlags: unsafe extern "stdcall" fn(*mut c_void) -> SlangCompileFlags,
pub setDumpIntermediates: unsafe extern "stdcall" fn(*mut c_void, enable: c_int),
pub setDumpIntermediatePrefix: unsafe extern "stdcall" fn(*mut c_void, prefix: *const c_char),
pub setLineDirectiveMode: unsafe extern "stdcall" fn(*mut c_void, mode: SlangLineDirectiveMode),
pub setCodeGenTarget: unsafe extern "stdcall" fn(*mut c_void, target: SlangCompileTarget),
pub addCodeGenTarget: unsafe extern "stdcall" fn(*mut c_void, target: SlangCompileTarget) -> c_int,
pub setTargetProfile: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, profile: SlangProfileID),
pub setTargetFlags: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, flags: SlangTargetFlags),
pub setTargetFloatingPointMode: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, mode: SlangFloatingPointMode),
#[deprecated( note = "Use setMatrixLayoutMode instead")]
pub setTargetMatrixLayoutMode: unsafe extern "stdcall" fn(*mut c_void, target: c_int, mode: SlangMatrixLayoutMode),
pub setMatrixLayoutMode: unsafe extern "stdcall" fn(*mut c_void, mode: SlangMatrixLayoutMode),
pub setDebugInfoLevel: unsafe extern "stdcall" fn(*mut c_void, level: SlangDebugInfoLevel),
pub setOptimizationLevel: unsafe extern "stdcall" fn(*mut c_void, level: SlangOptimizationLevel),
pub setOutputContainerFormat: unsafe extern "stdcall" fn(*mut c_void, format: SlangContainerFormat),
pub setPassThrough: unsafe extern "stdcall" fn(*mut c_void, passThrough: SlangPassThrough),
pub setDiagnosticCallback: unsafe extern "stdcall" fn(*mut c_void, callback: SlangDiagnosticCallback, userData: *const c_void),
pub setWriter: unsafe extern "stdcall" fn(*mut c_void, channel: SlangWriterChannel, writer: *mut ISlangWriter),
pub getWriter: unsafe extern "stdcall" fn(*mut c_void, channel: SlangWriterChannel) -> *mut ISlangWriter,
pub addSearchPath: unsafe extern "stdcall" fn(*mut c_void, searchDir: *const c_char),
pub addPreprocessorDefine: unsafe extern "stdcall" fn(*mut c_void, key: *const c_char, value: *const c_char),
pub processCommandLineArguments: unsafe extern "stdcall" fn(*mut c_void, args: *const *const c_char, argCount: c_int) -> SlangResult,
pub addTranslationUnit: unsafe extern "stdcall" fn(*mut c_void, language: SlangSourceLanguage, name: *const c_char) -> c_int,
pub setDefaultModuleName: unsafe extern "stdcall" fn(*mut c_void, defaultModuleName: *const c_char),
pub addTranslationUnitPreprocessorDefine: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, key: *const c_char, value: *const c_char),
pub addTranslationUnitSourceFile: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, path: *const c_char),
pub addTranslationUnitSourceString: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, path: *const c_char, source: *const c_char),
pub addLibraryReference: unsafe extern "stdcall" fn(*mut c_void, basePath: *const c_char, libData: *const c_void, libDataSize: usize) -> SlangResult,
pub addTranslationUnitSourceStringSpan: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, path: *const c_char, sourceBegin: *const c_char, sourceEnd: *const c_char),
pub addTranslationUnitSourceBlob: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, path: *const c_char, sourceBlob: *mut ISlangBlob),
pub addEntryPoint: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, name: *const c_char, stage: SlangStage) -> c_int,
pub addEntryPointEx: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, name: *const c_char, stage: SlangStage, genericArgCount: c_int, genericArgs: *const *const c_char) -> c_int,
pub setGlobalGenericArgs: unsafe extern "stdcall" fn(*mut c_void, genericArgCount: c_int, genericArgs: *const *const c_char) -> SlangResult,
pub setTypeNameForGlobalExistentialTypeParam: unsafe extern "stdcall" fn(*mut c_void, slotIndex: c_int, typeName: *const c_char) -> SlangResult,
pub setTypeNameForEntryPointExistentialTypeParam: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, slotIndex: c_int, typeName: *const c_char) -> SlangResult,
pub setAllowGLSLInput: unsafe extern "stdcall" fn(*mut c_void, value: bool),
pub compile: unsafe extern "stdcall" fn(*mut c_void) -> SlangResult,
pub getDiagnosticOutput: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char,
pub getDiagnosticOutputBlob: unsafe extern "stdcall" fn(*mut c_void, outBlob: *mut *mut ISlangBlob) -> SlangResult,
pub getDependencyFileCount: unsafe extern "stdcall" fn(*mut c_void) -> c_int,
pub getDependencyFilePath: unsafe extern "stdcall" fn(*mut c_void, index: c_int) -> *const c_char,
pub getTranslationUnitCount: unsafe extern "stdcall" fn(*mut c_void) -> c_int,
pub getEntryPointSource: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int) -> *const c_char,
pub getEntryPointCode: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, outSize: *mut usize) -> *const c_void,
pub getEntryPointCodeBlob: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, outBlob: *mut *mut ISlangBlob) -> SlangResult,
pub getEntryPointHostCallable: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, outSharedLibrary: *mut *mut ISlangSharedLibrary) -> SlangResult,
pub getTargetCodeBlob: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, outBlob: *mut *mut ISlangBlob) -> SlangResult,
pub getTargetHostCallable: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, outSharedLibrary: *mut *mut ISlangSharedLibrary) -> SlangResult,
pub getCompileRequestCode: unsafe extern "stdcall" fn(*mut c_void, outSize: *mut usize) -> *const c_void,
pub getCompileRequestResultAsFileSystem: unsafe extern "stdcall" fn(*mut c_void) -> *mut ISlangMutableFileSystem,
pub getContainerCode: unsafe extern "stdcall" fn(*mut c_void, outBlob: *mut *mut ISlangBlob) -> SlangResult,
pub loadRepro: unsafe extern "stdcall" fn(*mut c_void, fileSystem: *mut ISlangFileSystem, data: *const c_void, size: usize) -> SlangResult,
pub saveRepro: unsafe extern "stdcall" fn(*mut c_void, outBlob: *mut *mut ISlangBlob) -> SlangResult,
pub enableReproCapture: unsafe extern "stdcall" fn(*mut c_void) -> SlangResult,
pub getProgram: unsafe extern "stdcall" fn(*mut c_void, outProgram: *mut *mut slang_IComponentType) -> SlangResult,
pub getEntryPoint: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, outEntryPoint: *mut *mut slang_IComponentType) -> SlangResult,
pub getModule: unsafe extern "stdcall" fn(*mut c_void, translationUnitIndex: c_int, outModule: *mut *mut slang_IModule) -> SlangResult,
pub getSession: unsafe extern "stdcall" fn(*mut c_void, outSession: *mut *mut slang_ISession) -> SlangResult,
pub getReflection: unsafe extern "stdcall" fn(*mut c_void) -> *mut SlangReflection,
pub setCommandLineCompilerMode: unsafe extern "stdcall" fn(*mut c_void),
pub addTargetCapability: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, capability: SlangCapabilityID) -> SlangResult,
pub getProgramWithEntryPoints: unsafe extern "stdcall" fn(*mut c_void, outProgram: *mut *mut slang_IComponentType) -> SlangResult,
pub isParameterLocationUsed: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, category: SlangParameterCategory, spaceIndex: SlangUInt, registerIndex: SlangUInt, outUsed: *mut bool) -> SlangResult,
pub setTargetLineDirectiveMode: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, mode: SlangLineDirectiveMode),
pub setTargetForceGLSLScalarBufferLayout: unsafe extern "stdcall" fn(*mut c_void, targetIndex: c_int, forceScalarLayout: bool),
pub overrideDiagnosticSeverity: unsafe extern "stdcall" fn(*mut c_void, messageID: SlangInt, overrideSeverity: SlangSeverity),
pub getDiagnosticFlags: unsafe extern "stdcall" fn(*mut c_void) -> SlangDiagnosticFlags,
pub setDiagnosticFlags: unsafe extern "stdcall" fn(*mut c_void, flags: SlangDiagnosticFlags),
pub setDebugInfoFormat: unsafe extern "stdcall" fn(*mut c_void, debugFormat: SlangDebugInfoFormat),
pub setEnableEffectAnnotations: unsafe extern "stdcall" fn(*mut c_void, value: bool),
pub setReportDownstreamTime: unsafe extern "stdcall" fn(*mut c_void, value: bool),
pub setReportPerfBenchmark: unsafe extern "stdcall" fn(*mut c_void, value: bool),
pub setSkipSPIRVValidation: unsafe extern "stdcall" fn(*mut c_void, value: bool),
pub getGlobalSession: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_IGlobalSession,
pub loadModule: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule,
pub loadModuleFromSource: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, source: *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule,
pub createCompositeComponentType: unsafe extern "stdcall" fn(*mut c_void, componentTypes: *const *const slang_IComponentType, componentTypeCount: SlangInt, outCompositeComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult,
pub specializeType: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, specializationArgs: *const slang_SpecializationArg, specializationArgCount: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeReflection,
pub getTypeLayout: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, targetIndex: SlangInt, rules: slang_LayoutRules, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeLayoutReflection,
pub getContainerType: unsafe extern "stdcall" fn(*mut c_void, elementType: *mut slang_TypeReflection, containerType: slang_ContainerType, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_TypeReflection,
pub getDynamicType: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_TypeReflection,
pub getTypeRTTIMangledName: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, outNameBlob: *mut *mut ISlangBlob) -> SlangResult,
pub getTypeConformanceWitnessMangledName: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outNameBlob: *mut *mut ISlangBlob) -> SlangResult,
pub getTypeConformanceWitnessSequentialID: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outId: *mut u32) -> SlangResult,
pub createCompileRequest: unsafe extern "stdcall" fn(*mut c_void, outCompileRequest: *mut *mut slang_ICompileRequest) -> SlangResult,
pub createTypeConformanceComponentType: unsafe extern "stdcall" fn(*mut c_void, type_: *mut slang_TypeReflection, interfaceType: *mut slang_TypeReflection, outConformance: *mut *mut slang_ITypeConformance, conformanceIdOverride: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult,
pub loadModuleFromIRBlob: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, source: *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule,
pub getLoadedModuleCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt,
pub getLoadedModule: unsafe extern "stdcall" fn(*mut c_void, index: SlangInt) -> *mut slang_IModule,
pub isBinaryModuleUpToDate: unsafe extern "stdcall" fn(*mut c_void, modulePath: *const c_char, binaryModuleBlob: *mut ISlangBlob) -> bool,
pub loadModuleFromSourceString: unsafe extern "stdcall" fn(*mut c_void, moduleName: *const c_char, path: *const c_char, string: *const c_char, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_IModule,
}
#[repr(C)]
pub struct IComponentTypeVtable {
pub _base: ISlangUnknown__bindgen_vtable,
pub getSession: unsafe extern "stdcall" fn(*mut c_void) -> *mut slang_ISession,
pub getLayout: unsafe extern "stdcall" fn(*mut c_void, targetIndex: SlangInt, outDiagnostics: *mut *mut ISlangBlob) -> *mut slang_ProgramLayout,
pub getSpecializationParamCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt,
pub getEntryPointCode: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outCode: *mut *mut ISlangBlob, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult,
pub getResultAsFileSystem: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outFileSystem: *mut *mut ISlangMutableFileSystem) -> SlangResult,
pub getEntryPointHash: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: SlangInt, targetIndex: SlangInt, outHash: *mut *mut ISlangBlob),
pub specialize: unsafe extern "stdcall" fn(*mut c_void, specializationArgs: *const slang_SpecializationArg, specializationArgCount: SlangInt, outSpecializedComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult,
pub link: unsafe extern "stdcall" fn(*mut c_void, outLinkedComponentType: *mut *mut slang_IComponentType, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult,
pub getEntryPointHostCallable: unsafe extern "stdcall" fn(*mut c_void, entryPointIndex: c_int, targetIndex: c_int, outSharedLibrary: *mut *mut ISlangSharedLibrary, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult,
pub renameEntryPoint: unsafe extern "stdcall" fn(*mut c_void, newName: *const c_char, outEntryPoint: *mut *mut slang_IComponentType) -> SlangResult,
pub linkWithOptions: unsafe extern "stdcall" fn(*mut c_void, outLinkedComponentType: *mut *mut slang_IComponentType, compilerOptionEntryCount: u32, compilerOptionEntries: *mut slang_CompilerOptionEntry, outDiagnostics: *mut *mut ISlangBlob) -> SlangResult,
}
#[repr(C)]
pub struct IEntryPointVtable {
pub _base: IComponentTypeVtable,
}
#[repr(C)]
pub struct ITypeConformanceVtable {
pub _base: IComponentTypeVtable,
}
#[repr(C)]
pub struct IModuleVtable {
pub _base: IComponentTypeVtable,
pub findEntryPointByName: unsafe extern "stdcall" fn(*mut c_void, name: *const c_char, outEntryPoint: *mut *mut slang_IEntryPoint) -> SlangResult,
pub getDefinedEntryPointCount: unsafe extern "stdcall" fn(*mut c_void) -> SlangInt32,
pub getDefinedEntryPoint: unsafe extern "stdcall" fn(*mut c_void, index: SlangInt32, outEntryPoint: *mut *mut slang_IEntryPoint) -> SlangResult,
pub serialize: unsafe extern "stdcall" fn(*mut c_void, outSerializedBlob: *mut *mut ISlangBlob) -> SlangResult,
pub writeToFile: unsafe extern "stdcall" fn(*mut c_void, fileName: *const c_char) -> SlangResult,
pub getName: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char,
pub getFilePath: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char,
pub getUniqueIdentity: unsafe extern "stdcall" fn(*mut c_void) -> *const c_char,
}

View File

@ -1,15 +1,24 @@
use slang_sys as sys;
use std::ffi::{CStr, CString};
use std::path::Path;
use std::slice;
use std::ptr::{null, null_mut};
pub use sys::SlangUUID as UUID;
pub use sys::SlangCompileTarget as CompileTarget;
pub use sys::SlangMatrixLayoutMode as MatrixLayoutMode;
pub use sys::SlangOptimizationLevel as OptimizationLevel;
pub use sys::SlangSourceLanguage as SourceLanguage;
pub use sys::SlangStage as Stage;
pub use sys::SlangProfileID as ProfileID;
use slang_sys as sys;
pub use sys::{
slang_CompilerOptionName as CompilerOptionName,
slang_SessionDesc as SessionDesc,
slang_TargetDesc as TargetDesc,
SlangCapabilityID as CapabilityID,
SlangCompileTarget as CompileTarget,
SlangDebugInfoLevel as DebugInfoLevel,
SlangFloatingPointMode as FloatingPointMode,
SlangLineDirectiveMode as LineDirectiveMode,
SlangMatrixLayoutMode as MatrixLayoutMode,
SlangOptimizationLevel as OptimizationLevel,
SlangProfileID as ProfileID,
SlangSourceLanguage as SourceLanguage,
SlangStage as Stage,
SlangUUID as UUID,
};
macro_rules! vcall {
($self:expr, $method:ident($($args:expr),*)) => {
@ -34,6 +43,15 @@ unsafe trait Interface: Sized {
unsafe fn as_raw<T>(&self) -> *mut T {
std::mem::transmute_copy(self)
}
fn as_unknown(&self) -> &IUnknown {
// SAFETY: It is always safe to treat an `Interface` as an `IUnknown`.
unsafe { std::mem::transmute(self) }
}
}
pub unsafe trait Downcast<T> {
fn downcast(&self) -> &T;
}
#[repr(transparent)]
@ -41,7 +59,7 @@ pub struct IUnknown(std::ptr::NonNull<std::ffi::c_void>);
unsafe impl Interface for IUnknown {
type Vtable = sys::ISlangUnknown__bindgen_vtable;
const IID: UUID = uuid(0x00000000, 0x0000, 0x0000, [0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46]);
const IID: UUID = uuid(0x00000000, 0x0000, 0x0000, [0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46]);
}
impl Clone for IUnknown {
@ -57,168 +75,414 @@ impl Drop for IUnknown {
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct Blob(IUnknown);
unsafe impl Interface for Blob {
type Vtable = sys::IBlobVtable;
const IID: UUID = uuid(0x8ba5fb08, 0x5195, 0x40e2, [0xac, 0x58, 0x0d, 0x98, 0x9c, 0x3a, 0x01, 0x02]);
}
impl Blob {
pub fn as_slice(&self) -> &[u8] {
let ptr = vcall!(self, getBufferPointer());
let size = vcall!(self, getBufferSize());
unsafe { std::slice::from_raw_parts(ptr as *const u8, size) }
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct GlobalSession(IUnknown);
unsafe impl Interface for GlobalSession {
type Vtable = sys::IGlobalSessionVtable;
const IID: UUID = uuid(0xc140b5fd, 0xc78, 0x452e, [0xba, 0x7c, 0x1a, 0x1e, 0x70, 0xc7, 0xf7, 0x1c]);
const IID: UUID = uuid(0xc140b5fd, 0x0c78, 0x452e, [0xba, 0x7c, 0x1a, 0x1e, 0x70, 0xc7, 0xf7, 0x1c]);
}
impl GlobalSession {
pub fn new() -> GlobalSession {
unsafe {
let mut global_session = std::ptr::null_mut();
sys::slang_createGlobalSession(sys::SLANG_API_VERSION as _, &mut global_session);
GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _).unwrap()))
}
pub fn new() -> Option<GlobalSession> {
let mut global_session = null_mut();
unsafe { sys::slang_createGlobalSession(sys::SLANG_API_VERSION as _, &mut global_session) };
Some(GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _)?)))
}
pub fn new_without_std_lib() -> GlobalSession {
unsafe {
let mut global_session = std::ptr::null_mut();
sys::slang_createGlobalSessionWithoutStdLib(sys::SLANG_API_VERSION as _, &mut global_session);
GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _).unwrap()))
}
pub fn new_without_std_lib() -> Option<GlobalSession> {
let mut global_session = null_mut();
unsafe { sys::slang_createGlobalSessionWithoutStdLib(sys::SLANG_API_VERSION as _, &mut global_session) };
Some(GlobalSession(IUnknown(std::ptr::NonNull::new(global_session as *mut _)?)))
}
pub fn create_compile_request(&self) -> CompileRequest {
let mut compile_request = std::ptr::null_mut();
vcall!(self, createCompileRequest(&mut compile_request));
CompileRequest(IUnknown(std::ptr::NonNull::new(compile_request).unwrap()))
pub fn create_session(&self, desc: &SessionDesc) -> Option<Session> {
let mut session = null_mut();
let res = vcall!(self, createSession(desc, &mut session));
let session = Session(IUnknown(std::ptr::NonNull::new(session as *mut _)?));
// TODO: Without adding an extra reference, the code crashes when Session is dropped.
// Investigate why this is happening, the current solution could cause a memory leak.
unsafe { (session.as_unknown().vtable().ISlangUnknown_addRef)(session.as_raw()) };
Some(session)
}
pub fn find_profile(&self, name: &str) -> ProfileID {
let name = CString::new(name).unwrap();
vcall!(self, findProfile(name.as_ptr()))
}
pub fn find_capability(&self, name: &str) -> CapabilityID {
let name = CString::new(name).unwrap();
vcall!(self, findCapability(name.as_ptr()))
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct CompileRequest(IUnknown);
pub struct Session(IUnknown);
unsafe impl Interface for CompileRequest {
type Vtable = sys::ICompileRequestVtable;
const IID: UUID = uuid(0x96d33993, 0x317c, 0x4db5, [0xaf, 0xd8, 0x66, 0x6e, 0xe7, 0x72, 0x48, 0xe2]);
unsafe impl Interface for Session {
type Vtable = sys::ISessionVtable;
const IID: UUID = uuid(0x67618701, 0xd116, 0x468f, [0xab, 0x3b, 0x47, 0x4b, 0xed, 0xce, 0x0e, 0x3d]);
}
impl CompileRequest {
pub fn set_codegen_target(&mut self, target: CompileTarget) -> &mut Self {
vcall!(self, setCodeGenTarget(target));
self
}
pub fn set_matrix_layout_mode(&mut self, layout: MatrixLayoutMode) -> &mut Self {
vcall!(self, setMatrixLayoutMode(layout));
self
}
pub fn set_optimization_level(&mut self, level: OptimizationLevel) -> &mut Self {
vcall!(self, setOptimizationLevel(level));
self
}
pub fn set_target_profile(&mut self, profile: ProfileID) ->&mut Self {
vcall!(self, setTargetProfile(0, profile));
self
}
pub fn add_preprocessor_define(&mut self, key: &str, value: &str) -> &mut Self {
let key = CString::new(key).unwrap();
let value = CString::new(value).unwrap();
vcall!(self, addPreprocessorDefine(key.as_ptr(), value.as_ptr()));
self
}
pub fn add_search_path(&mut self, path: impl AsRef<Path>) -> &mut Self {
let path = CString::new(path.as_ref().to_str().unwrap()).unwrap();
vcall!(self, addSearchPath(path.as_ptr()));
self
}
pub fn add_translation_unit(&mut self, source_language: SourceLanguage, name: Option<&str>) -> TranslationUnit {
let name = CString::new(name.unwrap_or("")).unwrap();
let index = vcall!(self, addTranslationUnit(source_language, name.as_ptr()));
TranslationUnit {
request: self,
index,
}
}
pub fn compile(self) -> Result<CompiledRequest, CompilationErrors> {
let r = vcall!(self, compile());
if r < 0 {
let out = vcall!(self, getDiagnosticOutput());
let errors = unsafe { CStr::from_ptr(out).to_str().unwrap().to_string() };
Err(CompilationErrors { errors })
} else {
Ok(CompiledRequest { request: self })
}
}
}
pub struct TranslationUnit<'a> {
request: &'a mut CompileRequest,
index: i32,
}
impl<'a> TranslationUnit<'a> {
pub fn add_preprocessor_define(&mut self, key: &str, value: &str) -> &mut Self {
let key = CString::new(key).unwrap();
let value = CString::new(value).unwrap();
vcall!(self.request, addTranslationUnitPreprocessorDefine(self.index, key.as_ptr(), value.as_ptr()));
self
}
pub fn add_source_file(&mut self, path: impl AsRef<Path>) -> &mut Self {
let path = CString::new(path.as_ref().to_str().unwrap()).unwrap();
vcall!(self.request, addTranslationUnitSourceFile(self.index, path.as_ptr()));
self
}
pub fn add_source_string(&mut self, path: impl AsRef<Path>, source: &str) -> &mut Self {
let path = CString::new(path.as_ref().to_str().unwrap()).unwrap();
let source = CString::new(source).unwrap();
vcall!(self.request, addTranslationUnitSourceString(self.index, path.as_ptr(), source.as_ptr()));
self
}
pub fn add_entry_point(&mut self, name: &str, stage: Stage) -> EntryPointIndex {
impl Session {
pub fn load_module(&self, name: &str) -> Result<Module, String> {
let name = CString::new(name).unwrap();
let index = vcall!(self.request, addEntryPoint(self.index, name.as_ptr(), stage));
EntryPointIndex(index)
let mut diagnostics = null_mut();
let module = vcall!(self, loadModule(name.as_ptr(), &mut diagnostics));
if module.is_null() {
let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap()));
Err(std::str::from_utf8(blob.as_slice()).unwrap().to_string())
} else {
Ok(Module(IUnknown(std::ptr::NonNull::new(module as *mut _).unwrap())))
}
}
pub struct CompiledRequest {
request: CompileRequest,
}
pub fn create_composite_component_type(&self, components: &[&ComponentType]) -> ComponentType {
let components: Vec<*mut std::ffi::c_void> = unsafe { components.iter().map(|c| c.as_raw()).collect() };
impl CompiledRequest {
pub fn get_entry_point_code(&self, index: EntryPointIndex) -> &[u8] {
let mut out_size = 0;
let ptr = vcall!(self.request, getEntryPointCode(index.0, &mut out_size));
unsafe { slice::from_raw_parts(ptr as *const u8, out_size) }
let mut composite_component_type = null_mut();
let mut diagnostics = null_mut();
let res = vcall!(self, createCompositeComponentType(components.as_ptr() as _, components.len() as _, &mut composite_component_type, &mut diagnostics));
ComponentType(IUnknown(std::ptr::NonNull::new(composite_component_type as *mut _).unwrap()))
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct EntryPointIndex(pub i32);
#[repr(transparent)]
#[derive(Clone)]
pub struct ComponentType(IUnknown);
pub struct CompilationErrors {
errors: String,
unsafe impl Interface for ComponentType {
type Vtable = sys::IComponentTypeVtable;
const IID: UUID = uuid(0x5bc42be8, 0x5c50, 0x4929, [0x9e, 0x5e, 0xd1, 0x5e, 0x7c, 0x24, 0x01, 0x5f]);
}
impl std::fmt::Debug for CompilationErrors {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("\n")?;
f.write_str(&self.errors)
impl ComponentType {
pub fn link(&self) -> ComponentType {
let mut linked_component_type = null_mut();
let mut diagnostics = null_mut();
let res = vcall!(self, link(&mut linked_component_type, &mut diagnostics));
if linked_component_type.is_null() {
let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap()));
let error = std::str::from_utf8(blob.as_slice()).unwrap().to_string();
println!("Error: {}", error);
}
ComponentType(IUnknown(std::ptr::NonNull::new(linked_component_type as *mut _).unwrap()))
}
pub fn get_entry_point_code(&self, index: i64, target: i64) -> Vec<u8> {
let mut code = null_mut();
let mut diagnostics = null_mut();
let res = vcall!(self, getEntryPointCode(index, target, &mut code, &mut diagnostics));
if code.is_null() {
let blob = Blob(IUnknown(std::ptr::NonNull::new(diagnostics as *mut _).unwrap()));
let error = std::str::from_utf8(blob.as_slice()).unwrap().to_string();
println!("Error: {}", error);
}
let blob = Blob(IUnknown(std::ptr::NonNull::new(code as *mut _).unwrap()));
Vec::from(blob.as_slice())
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct EntryPoint(IUnknown);
unsafe impl Interface for EntryPoint {
type Vtable = sys::IEntryPointVtable;
const IID: UUID = uuid(0x8f241361, 0xf5bd, 0x4ca0, [0xa3, 0xac, 0x02, 0xf7, 0xfa, 0x24, 0x02, 0xb8]);
}
unsafe impl Downcast<ComponentType> for EntryPoint {
fn downcast(&self) -> &ComponentType {
unsafe { std::mem::transmute(self) }
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct TypeConformance(IUnknown);
unsafe impl Interface for TypeConformance {
type Vtable = sys::ITypeConformanceVtable;
const IID: UUID = uuid(0x73eb3147, 0xe544, 0x41b5, [0xb8, 0xf0, 0xa2, 0x44, 0xdf, 0x21, 0x94, 0x0b]);
}
unsafe impl Downcast<ComponentType> for TypeConformance {
fn downcast(&self) -> &ComponentType {
unsafe { std::mem::transmute(self) }
}
}
#[repr(transparent)]
#[derive(Clone)]
pub struct Module(IUnknown);
unsafe impl Interface for Module {
type Vtable = sys::IModuleVtable;
const IID: UUID = uuid(0x0c720e64, 0x8722, 0x4d31, [0x89, 0x90, 0x63, 0x8a, 0x98, 0xb1, 0xc2, 0x79]);
}
unsafe impl Downcast<ComponentType> for Module {
fn downcast(&self) -> &ComponentType {
unsafe { std::mem::transmute(self) }
}
}
impl Module {
pub fn find_entry_point_by_name(&self, name: &str) -> Option<EntryPoint> {
let name = CString::new(name).unwrap();
let mut entry_point = null_mut();
vcall!(self, findEntryPointByName(name.as_ptr(), &mut entry_point));
Some(EntryPoint(IUnknown(std::ptr::NonNull::new(entry_point as *mut _)?)))
}
pub fn name(&self) -> &str {
let name = vcall!(self, getName());
unsafe { CStr::from_ptr(name).to_str().unwrap() }
}
pub fn file_path(&self) -> &str {
let path = vcall!(self, getFilePath());
unsafe { CStr::from_ptr(path).to_str().unwrap() }
}
pub fn unique_identity(&self) -> &str {
let identity = vcall!(self, getUniqueIdentity());
unsafe { CStr::from_ptr(identity).to_str().unwrap() }
}
}
pub struct TargetDescBuilder {
inner: TargetDesc,
}
impl TargetDescBuilder {
pub fn new() -> TargetDescBuilder {
Self {
inner: TargetDesc {
structureSize: std::mem::size_of::<TargetDesc>(),
..unsafe { std::mem::zeroed() }
}
}
}
pub fn format(mut self, format: CompileTarget) -> Self {
self.inner.format = format;
self
}
pub fn profile(mut self, profile: ProfileID) -> Self {
self.inner.profile = profile;
self
}
pub fn options(mut self, options: &OptionsBuilder) -> Self {
self.inner.compilerOptionEntries = options.options.as_ptr() as _;
self.inner.compilerOptionEntryCount = options.options.len() as _;
self
}
}
impl std::ops::Deref for TargetDescBuilder {
type Target = TargetDesc;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
pub struct SessionDescBuilder {
inner: SessionDesc,
}
impl SessionDescBuilder {
pub fn new() -> SessionDescBuilder {
Self {
inner: SessionDesc {
structureSize: std::mem::size_of::<SessionDesc>(),
..unsafe { std::mem::zeroed() }
}
}
}
pub fn targets(mut self, targets: &[TargetDesc]) -> Self {
self.inner.targets = targets.as_ptr();
self.inner.targetCount = targets.len() as _;
self
}
pub fn search_paths(mut self, paths: &[*const i8]) -> Self {
self.inner.searchPaths = paths.as_ptr();
self.inner.searchPathCount = paths.len() as _;
self
}
pub fn options(mut self, options: &OptionsBuilder) -> Self {
self.inner.compilerOptionEntries = options.options.as_ptr() as _;
self.inner.compilerOptionEntryCount = options.options.len() as _;
self
}
}
impl std::ops::Deref for SessionDescBuilder {
type Target = SessionDesc;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
macro_rules! option {
($name:ident, $func:ident($p_name:ident: $p_type:ident)) => {
#[inline(always)]
pub fn $func(self, $p_name: $p_type) -> Self {
self.push_ints(CompilerOptionName::$name, $p_name as _, 0)
}
};
($name:ident, $func:ident($p_name:ident: &str)) => {
#[inline(always)]
pub fn $func(self, $p_name: &str) -> Self {
self.push_str1(CompilerOptionName::$name, $p_name)
}
};
($name:ident, $func:ident($p_name1:ident: &str, $p_name2:ident: &str)) => {
#[inline(always)]
pub fn $func(self, $p_name1: &str, $p_name2: &str) -> Self {
self.push_str2(CompilerOptionName::$name, $p_name1, $p_name2)
}
};
}
pub struct OptionsBuilder {
strings: Vec<CString>,
options: Vec<sys::slang_CompilerOptionEntry>,
}
impl OptionsBuilder {
pub fn new() -> OptionsBuilder {
OptionsBuilder {
strings: Vec::new(),
options: Vec::new(),
}
}
fn push_ints(mut self, name: CompilerOptionName, i0: i32, i1: i32) -> Self {
self.options.push(sys::slang_CompilerOptionEntry {
name,
value: sys::slang_CompilerOptionValue {
kind: sys::slang_CompilerOptionValueKind::Int,
intValue0: i0,
intValue1: i1,
stringValue0: null(),
stringValue1: null(),
},
});
self
}
fn push_strings(mut self, name: CompilerOptionName, s0: *const i8, s1: *const i8) -> Self {
self.options.push(sys::slang_CompilerOptionEntry {
name,
value: sys::slang_CompilerOptionValue {
kind: sys::slang_CompilerOptionValueKind::String,
intValue0: 0,
intValue1: 0,
stringValue0: s0,
stringValue1: s1,
},
});
self
}
fn push_str1(mut self, name: CompilerOptionName, s0: &str) -> Self {
let s0 = CString::new(s0).unwrap();
let s0_ptr = s0.as_ptr();
self.strings.push(s0);
self.push_strings(name, s0_ptr, null())
}
fn push_str2(mut self, name: CompilerOptionName, s0: &str, s1: &str) -> Self {
let s0 = CString::new(s0).unwrap();
let s0_ptr = s0.as_ptr();
self.strings.push(s0);
let s1 = CString::new(s1).unwrap();
let s1_ptr = s1.as_ptr();
self.strings.push(s1);
self.push_strings(name, s0_ptr, s1_ptr)
}
}
impl OptionsBuilder {
option!(MacroDefine, macro_define(key: &str, value: &str));
option!(Include, include(path: &str));
option!(Language, language(language: SourceLanguage));
option!(MatrixLayoutColumn, matrix_layout_column(enable: bool));
option!(MatrixLayoutRow, matrix_layout_row(enable: bool));
option!(Profile, profile(profile: ProfileID));
option!(Stage, stage(stage: Stage));
option!(Target, target(target: CompileTarget));
option!(WarningsAsErrors, warnings_as_errors(warning_codes: &str));
option!(DisableWarnings, disable_warnings(warning_codes: &str));
option!(EnableWarning, enable_warning(warning_code: &str));
option!(DisableWarning, disable_warning(warning_code: &str));
option!(ReportDownstreamTime, report_downstream_time(enable: bool));
option!(ReportPerfBenchmark, report_perf_benchmark(enable: bool));
option!(SkipSPIRVValidation, skip_spirv_validation(enable: bool));
// Target
option!(Capability, capability(capability: CapabilityID));
option!(DefaultImageFormatUnknown, default_image_format_unknown(enable: bool));
option!(DisableDynamicDispatch, disable_dynamic_dispatch(enable: bool));
option!(DisableSpecialization, disable_specialization(enable: bool));
option!(FloatingPointMode, floating_point_mode(mode: FloatingPointMode));
option!(DebugInformation, debug_information(level: DebugInfoLevel));
option!(LineDirectiveMode, line_directive_mode(mode: LineDirectiveMode));
option!(Optimization, optimization(level: OptimizationLevel));
option!(Obfuscate, obfuscate(enable: bool));
option!(GLSLForceScalarLayout, glsl_force_scalar_layout(enable: bool));
option!(EmitSpirvDirectly, emit_spirv_directly(enable: bool));
// Debugging
option!(NoCodeGen, no_code_gen(enable: bool));
// Experimental
option!(NoMangle, no_mangle(enable: bool));
option!(ValidateUniformity, validate_uniformity(enable: bool));
}
#[cfg(test)]