// // Created by bspeice on 12/28/24. // #include "lib.h" #include #include #include "slang-tools/src/lib.rs.h" namespace slang_compiler { class slang_exception final : public std::exception { public: slang_exception(SlangResult result) : result_{result} {} const char *what() const noexcept override { return what_.c_str(); } private: SlangResult result_; std::string what_; }; } // namespace slang_compiler #define SLANG_TOOLS_CHECK(expr) \ [&]() { \ const SlangResult result = expr; \ if (SLANG_FAILED(result)) { \ throw slang_compiler::slang_exception(result); \ } \ }(); namespace slang_compiler { GlobalSession::GlobalSession() { slang::createGlobalSession(global_session_.writeRef()); } slang::IGlobalSession &GlobalSession::getSlangSession() { return *global_session_; } SlangResult GlobalSession::createSession(const slang::SessionDesc &desc, slang::ISession **outSession) const { return global_session_->createSession(desc, outSession); } std::shared_ptr create_global_session() { return std::make_shared(); } Session::Session(std::shared_ptr global_session, Slang::ComPtr &&session) : global_session_{global_session}, session_{session} {} int64_t Session::get_loaded_module_count() const noexcept { return session_->getLoadedModuleCount(); } ::SlangProfileID get_slang_profile_id(slang::IGlobalSession &global_session, const SlangProfileID_rs slang_profile_id) { switch (slang_profile_id) { case SlangProfileID_rs::spirv_1_0: return global_session.findProfile("spirv_1_0"); } return SLANG_PROFILE_UNKNOWN; } ::SlangTargetFlags get_slang_target_flags(const rust::Vec &flags) noexcept { uint32_t value = 0; std::for_each(flags.cbegin(), flags.cend(), [&value](const auto &flag) { value |= static_cast>(flag); }); return value; } std::unique_ptr create_session(SessionDesc session_desc, std::shared_ptr global_session) { // The Slang session descriptor wants unowned pointers, // so we copy the values from Rust into storage that // survives at least through the call to `createSession`. // Slang will maintain its own copy once as part of the session. // I'm sure there's a better way to do this, but this works for now. // First, convert the compiler options for each target descriptor std::vector targets_options_strings{}; std::vector> targets_options{}; for (const auto &target : session_desc.targets) { std::vector target_options; for (const CompilerOptionEntry &target_option : target) { auto& string_value_0 = targets_options_strings.emplace_back(target_option.value.string_value_0); auto& string_value_1 = targets_options_strings.emplace_back(target_option.value.string_value_1); auto slang_option = slang::CompilerOptionEntry{ .name = target_option.name, .value = slang::CompilerOptionValue{ .kind = target_option.value.kind, .intValue0 = target_option.value.int_value_0, .intValue1 = target_option.value.int_value_1, .stringValue0 = string_value_0.c_str(), .stringValue1 = string_value_1.c_str() } }; target_options.push_back(slang_option); } targets_options.push_back(std::move(target_options)); } // Second, convert the target descriptors std::vector target_descs; for (size_t i = 0; i < session_desc.targets.size(); ++i) { const auto& target = session_desc.targets[i]; auto& target_options = targets_options[i]; const auto target_desc = slang::TargetDesc{ .format = target.format, .profile = get_slang_profile_id(global_session->getSlangSession(), target.profile), .floatingPointMode = target.floating_point_mode, .lineDirectiveMode = target.line_directive_mode, .forceGLSLScalarBufferLayout = target.force_glsl_scalar_buffer_layout, .compilerOptionEntries = target_options.data(), .compilerOptionEntryCount = static_cast(target_options.size()), }; target_descs.push_back(target_desc); } // Finally, build the session descriptor for slang. const auto slang_session_desc = slang::SessionDesc{ .targets = target_descs.data(), .targetCount = static_cast(target_descs.size()), }; Slang::ComPtr session_ptr; SLANG_TOOLS_CHECK(global_session->createSession(slang_session_desc, session_ptr.writeRef())); return std::make_unique(std::move(global_session), std::move(session_ptr)); } } // namespace slang_compiler