zengine/engine/modules/engine/render/src/tool/glsl_to_spirv.cpp
2024-11-02 17:55:55 +08:00

116 lines
4.4 KiB
C++

#include "render/tool/glsl_to_spirv.h"

#include <algorithm>
#include <cstdint>
#include <string>

#include <shaderc/shaderc.hpp>
#include <spirv_cross/spirv_reflect.hpp>

#include "render/type.h"
#include "zlog.h"
#include "meta/enum.h"
namespace api
{
static pmr::table<Guid, ShaderDescriptorSet> ShaderLayoutTable;
optional<pmr::vector<uint32_t>> GlslToSpirv::ToSpirv(const pmr::string& glsl, shaderc_shader_kind kind, string_view code_id)
{
optional<pmr::vector<uint32_t>> spirv_out{FramePool()};
{
shaderc::Compiler compiler;
shaderc::CompileOptions options;
options.SetOptimizationLevel(shaderc_optimization_level_performance);
options.SetTargetEnvironment(shaderc_target_env::shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_3);
auto result = compiler.CompileGlslToSpv(glsl.data(), kind, code_id.data(), options);
if (result.GetCompilationStatus() != shaderc_compilation_status::shaderc_compilation_status_success) {
auto err_m = result.GetErrorMessage();
auto err_msg = err_m.c_str();
zlog::error("load spirv failed!!! {}", err_msg);
return spirv_out;
}
spirv_out = pmr::vector<uint32_t>{ result.cbegin(),result.cend() };
}
return spirv_out;
}
shaderc_shader_kind GlslToSpirv::GetShaderKind(Name name)
{
switch (name.Hash())
{
case string_hash(".vert"): return shaderc_shader_kind::shaderc_vertex_shader;
case string_hash(".geom"): return shaderc_shader_kind::shaderc_geometry_shader;
case string_hash(".tese"): return shaderc_shader_kind::shaderc_tess_evaluation_shader;
case string_hash(".tesc"): return shaderc_shader_kind::shaderc_tess_control_shader;
case string_hash(".frag"): return shaderc_shader_kind::shaderc_fragment_shader;
default: return shaderc_shader_kind::shaderc_miss_shader;
}
}
ShaderStage GlslToSpirv::GetShaderStage(shaderc_shader_kind kind)
{
switch (kind)
{
case shaderc_shader_kind::shaderc_vertex_shader:
return ShaderStage::VERTEX;
case shaderc_shader_kind::shaderc_fragment_shader:
return ShaderStage::FRAGMENT;
default:
return ShaderStage::NONE;
}
}
void GlslToSpirv::LoadShaderLayout(ShaderProgram* program, const pmr::vector<uint32_t>& spirv) {
ShaderStage stage = program->GetStage();
spirv_cross::Compiler compiler(spirv.data(), spirv.size());
auto resources = compiler.get_shader_resources();
ShaderDescriptorSet descriptorSet;
// 遍历 uniform buffers (UBO)
for (auto& res : resources.uniform_buffers)
{
ShaderDescriptorLayout layout{};
layout.set = compiler.get_decoration(res.id, spv::Decoration::DecorationDescriptorSet);
layout.binding = compiler.get_decoration(res.id, spv::Decoration::DecorationBinding);
auto type = compiler.get_type(res.type_id);
uint32_t size = type.width;
uint32_t i = 0;
for (auto& member_type : type.member_types)
{
auto tmp = compiler.get_type(member_type);
size += (uint32_t)(compiler.get_declared_struct_member_size(type, i++));
}
layout.size = size;
layout.stageFlags = stage;
layout.type = ShaderDescriptorType::UNIFORM_BUFFER;
descriptorSet.push_back(layout);
}
// 遍历 sampled images (采样器/纹理)
for (const auto& res : resources.sampled_images) {
ShaderDescriptorLayout layout{};
layout.set = compiler.get_decoration(res.id, spv::Decoration::DecorationDescriptorSet);
layout.binding = compiler.get_decoration(res.id, spv::Decoration::DecorationBinding);
layout.size = 0;
layout.stageFlags = stage;
layout.type = ShaderDescriptorType::SAMPLER;
descriptorSet.push_back(layout);
}
ShaderLayoutTable.emplace(program->GetGuid(), descriptorSet);
}
void GlslToSpirv::GetShaderLayout(ShaderDescriptorSet& descriptorSet, pmr::vector<ShaderProgram*>& programList)
{
table<uint32_t, uint32_t> idTable{ FramePool() };
for (auto program : programList) {
auto& itSet = ShaderLayoutTable[program->GetGuid()];
for (ShaderDescriptorLayout& layout : itSet) {
uint32_t id = (layout.set << 8) + layout.binding;
auto itId = idTable.find(id);
if (itId == idTable.end()) {
descriptorSet.emplace_back(layout);
idTable[id] = descriptorSet.size() - 1;
}
else {
using ::operator|=;
ShaderDescriptorLayout& desc = descriptorSet[itId->second];
desc.stageFlags |= layout.stageFlags;
}
}
}
std::sort(descriptorSet.begin(), descriptorSet.end(), [](auto& k1, auto& k2) {
if (k1.set != k2.set) {
return k1.set < k2.set;
}
return k1.binding < k2.binding;
});
}
}