comparison semiconginev2/old/vulkan/shader.nim @ 1218:56781cc0fc7c compiletime-tests

did: rename main package
author sam <sam@basx.dev>
date Wed, 17 Jul 2024 21:01:37 +0700
parents semicongine/old/vulkan/shader.nim@a3eb305bcac2
children
comparison
equal deleted inserted replaced
1217:f819a874058f 1218:56781cc0fc7c
1 import std/typetraits
2 import std/os
3 import std/enumerate
4 import std/logging
5 import std/hashes
6 import std/strformat
7 import std/strutils
8
9 import ../core
10 import ./device
11
const
  DEFAULT_SHADER_VERSION = 450 ## GLSL `#version` emitted when none is given.
  DEFAULT_SHADER_ENTRYPOINT = "main" ## Default name of the generated entry function.
14
# Importing this module registers a console handler with std/logging.
# NOTE(review): this is a global side effect of the import; std/logging sends
# every message to all registered handlers, so an application that also adds a
# console logger will see each message twice.
let logger = newConsoleLogger()
logger.addHandler()
17
type
  ShaderConfiguration* = object
    ## Vertex + fragment shader pair, compiled to SPIR-V at compile time,
    ## together with the interface attributes it was generated from.
    name*: string
    vertexBinary: seq[uint32]   ## SPIR-V words of the vertex stage
    fragmentBinary: seq[uint32] ## SPIR-V words of the fragment stage
    entrypoint: string          ## entry function name passed to Vulkan
    inputs*: seq[ShaderAttribute]        ## vertex stage inputs
    intermediates*: seq[ShaderAttribute] ## vertex outputs = fragment inputs
    outputs*: seq[ShaderAttribute]       ## fragment stage outputs
    uniforms*: seq[ShaderAttribute]      ## shared by both stages
    samplers*: seq[ShaderAttribute]      ## shared by both stages
  ShaderModule* = object
    ## A Vulkan shader module for a single stage of a ShaderConfiguration.
    device: Device
    vk*: VkShaderModule
    stage*: VkShaderStageFlagBits
    configuration*: ShaderConfiguration
34
proc `$`*(shader: ShaderConfiguration): string =
  ## String form of a shader configuration: just its `name`.
  # More verbose variant, kept for debugging:
  # &"Inputs: {shader.inputs}, Uniforms: {shader.uniforms}, Samplers: {shader.samplers}"
  result = shader.name
38
proc compileGlslToSPIRV(stage: VkShaderStageFlagBits, shaderSource: string, entrypoint: string): seq[uint32] {.compileTime.} =
  ## Compiles GLSL `shaderSource` for a single `stage` into SPIR-V words.
  ##
  ## Runs at compile time. The repository's `tools/glslangValidator` is only
  ## invoked when no cached output file exists in `getTempDir()`. The cache
  ## file name is derived from the source text and `entrypoint`, with the
  ## stage encoded in the extension. The entrypoint is passed on the command
  ## line rather than read from the source, so it must be part of the key.
  ## Returns the module as 32-bit words assembled from little-endian bytes.
  ## Returns an empty sequence when running under `nimcheck`.
  func stage2string(stage: VkShaderStageFlagBits): string {.compileTime.} =
    case stage
    of VK_SHADER_STAGE_VERTEX_BIT: "vert"
    of VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT: "tesc"
    of VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT: "tese"
    of VK_SHADER_STAGE_GEOMETRY_BIT: "geom"
    of VK_SHADER_STAGE_FRAGMENT_BIT: "frag"
    of VK_SHADER_STAGE_COMPUTE_BIT: "comp"
    else: ""

  when defined(nimcheck): # will not run if nimcheck is running
    return result

  let stagename = stage2string(stage)
  # glslangValidator's `-S` flag needs one of the stage names above; an empty
  # name would only fail later with a less helpful tool error.
  doAssert stagename.len > 0, &"unsupported shader stage for GLSL compilation: {stage}"

  # Everything that changes the compiled output (besides the stage, which is
  # in the file extension) goes into the cache key.
  var cacheKey: Hash = 0
  cacheKey = cacheKey !& hash(shaderSource)
  cacheKey = cacheKey !& hash(entrypoint)
  let
    shaderHash = !$cacheKey
    shaderfile = getTempDir() / &"shader_{shaderHash}.{stagename}"

  if not shaderfile.fileExists:
    echo "shader of type ", stage, ", entrypoint ", entrypoint
    for i, line in enumerate(shaderSource.splitLines()):
      echo " ", i + 1, " ", line
    var glslExe = currentSourcePath.parentDir.parentDir.parentDir / "tools" / "glslangValidator"
    when defined(windows):
      glslExe = glslExe & "." & ExeExt
    let command = &"{glslExe} --entry-point {entrypoint} -V --stdin -S {stagename} -o {shaderfile}"
    echo "run: ", command
    discard StaticExecChecked(
      command = command,
      input = shaderSource
    )
  else:
    echo &"shaderfile {shaderfile} is up-to-date"

  when defined(mingw) and defined(linux): # required for crosscompilation, path separators get messed up
    let shaderbinary = staticRead shaderfile.replace("\\", "/")
  else:
    let shaderbinary = staticRead shaderfile

  # SPIR-V consists of whole 32-bit words. A truncated or empty file (e.g. a
  # stale or corrupted cache entry) would otherwise surface as an index error
  # in the decoding loop below, or as an empty binary at runtime.
  doAssert shaderbinary.len > 0 and shaderbinary.len mod 4 == 0,
    &"invalid SPIR-V file {shaderfile} ({shaderbinary.len} bytes); delete it to force recompilation"

  result = newSeqOfCap[uint32](shaderbinary.len div 4)
  for i in countup(0, shaderbinary.len - 1, 4):
    result.add(
      (uint32(shaderbinary[i + 0]) shl 0) or
      (uint32(shaderbinary[i + 1]) shl 8) or
      (uint32(shaderbinary[i + 2]) shl 16) or
      (uint32(shaderbinary[i + 3]) shl 24)
    )
89
proc compileGlslCode(
  stage: VkShaderStageFlagBits,
  inputs: openArray[ShaderAttribute] = [],
  uniforms: openArray[ShaderAttribute] = [],
  samplers: openArray[ShaderAttribute] = [],
  outputs: openArray[ShaderAttribute] = [],
  version = DEFAULT_SHADER_VERSION,
  entrypoint = DEFAULT_SHADER_ENTRYPOINT,
  main: string
): seq[uint32] {.compileTime.} =
  ## Generates a complete GLSL translation unit and compiles it to SPIR-V.
  ##
  ## The unit starts with `#version` and the scalar block layout extension,
  ## followed by the declarations for `inputs`, `uniforms` (binding 0),
  ## `samplers` (starting at binding 1 when uniforms exist, else 0) and
  ## `outputs`. Each non-empty group is followed by a blank line. `main` is
  ## used verbatim as the body of `void <entrypoint>()`.
  var sourceLines = @[&"#version {version}", "#extension GL_EXT_scalar_block_layout : require", ""]
  if inputs.len > 0:
    sourceLines.add inputs.GlslInput()
    sourceLines.add ""
  if uniforms.len > 0:
    sourceLines.add uniforms.GlslUniforms(binding = 0)
    sourceLines.add ""
  if samplers.len > 0:
    sourceLines.add samplers.GlslSamplers(basebinding = if uniforms.len > 0: 1 else: 0)
    sourceLines.add ""
  if outputs.len > 0:
    sourceLines.add outputs.GlslOutput()
    sourceLines.add ""
  sourceLines.add &"void {entrypoint}(){{"
  sourceLines.add main
  sourceLines.add &"}}"
  compileGlslToSPIRV(stage, sourceLines.join("\n"), entrypoint)
110
proc CreateShaderConfiguration*(
  name: string,
  inputs: openArray[ShaderAttribute] = [],
  intermediates: openArray[ShaderAttribute] = [],
  outputs: openArray[ShaderAttribute] = [],
  uniforms: openArray[ShaderAttribute] = [],
  samplers: openArray[ShaderAttribute] = [],
  version = DEFAULT_SHADER_VERSION,
  entrypoint = DEFAULT_SHADER_ENTRYPOINT,
  vertexCode: string,
  fragmentCode: string,
): ShaderConfiguration {.compileTime.} =
  ## Builds a shader configuration by compiling both stages at compile time.
  ##
  ## The vertex stage reads `inputs` and writes `intermediates`. The fragment
  ## stage reads `intermediates` and writes `outputs`. Both stages share
  ## `uniforms` and `samplers`. `vertexCode` / `fragmentCode` are the bodies
  ## of the generated entry functions.
  ## `version` and `entrypoint` are applied to both stages. The stored
  ## `entrypoint` must match what the SPIR-V was compiled with, because it is
  ## later handed to Vulkan as the stage's entry point name.
  ShaderConfiguration(
    name: name,
    vertexBinary: compileGlslCode(
      stage = VK_SHADER_STAGE_VERTEX_BIT,
      inputs = inputs,
      outputs = intermediates,
      uniforms = uniforms,
      samplers = samplers,
      version = version,
      entrypoint = entrypoint,
      main = vertexCode,
    ),
    fragmentBinary: compileGlslCode(
      stage = VK_SHADER_STAGE_FRAGMENT_BIT,
      inputs = intermediates,
      outputs = outputs,
      uniforms = uniforms,
      samplers = samplers,
      version = version,
      entrypoint = entrypoint,
      main = fragmentCode,
    ),
    entrypoint: entrypoint,
    inputs: @inputs,
    intermediates: @intermediates,
    outputs: @outputs,
    uniforms: @uniforms,
    samplers: @samplers,
  )
148
149
proc CreateShaderModules*(
  device: Device,
  shaderConfiguration: ShaderConfiguration,
): (ShaderModule, ShaderModule) =
  ## Creates the Vulkan shader modules for `shaderConfiguration` on `device`.
  ##
  ## Returns `(vertexModule, fragmentModule)`. Each module stores `device`,
  ## a copy of `shaderConfiguration` and its stage bit. The vertex module is
  ## created before the fragment module; failures go through `checkVkResult`.
  assert device.vk.Valid
  assert len(shaderConfiguration.vertexBinary) > 0
  assert len(shaderConfiguration.fragmentBinary) > 0

  # Fills one module's fields and creates its VkShaderModule from `code`
  # (a non-empty seq[uint32] of SPIR-V words, guaranteed by the asserts).
  template buildModule(target, stageBit, code: untyped) =
    target.device = device
    target.configuration = shaderConfiguration
    target.stage = stageBit
    var createInfo = VkShaderModuleCreateInfo(
      sType: VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
      codeSize: uint(code.len * sizeof(uint32)),
      pCode: addr(code[0]),
    )
    checkVkResult vkCreateShaderModule(device.vk, addr(createInfo), nil, addr(target.vk))

  buildModule(result[0], VK_SHADER_STAGE_VERTEX_BIT, shaderConfiguration.vertexBinary)
  buildModule(result[1], VK_SHADER_STAGE_FRAGMENT_BIT, shaderConfiguration.fragmentBinary)
177
proc GetVertexInputInfo*(
  shaderConfiguration: ShaderConfiguration,
  bindings: var seq[VkVertexInputBindingDescription],
  attributes: var seq[VkVertexInputAttributeDescription],
  baseBinding = 0'u32
): VkPipelineVertexInputStateCreateInfo =
  ## Builds the vertex input state for the configuration's `inputs`.
  ##
  ## Every input attribute gets its own binding, numbered consecutively from
  ## `baseBinding`. The input rate is per-instance when `attribute.perInstance`
  ## is set, else per-vertex. Descriptions are appended to `bindings` and
  ## `attributes`. The returned counts cover the whole sequences, including
  ## entries present before the call.
  ## The returned struct points into both sequences (via `ToCPointer`), so
  ## presumably they must outlive it and stay unmodified while it is in use.
  var nextLocation = 0'u32
  for index, attribute in shaderConfiguration.inputs:
    let
      currentBinding = baseBinding + uint32(index)
      rate =
        if attribute.perInstance: VK_VERTEX_INPUT_RATE_INSTANCE
        else: VK_VERTEX_INPUT_RATE_VERTEX
    bindings.add VkVertexInputBindingDescription(
      binding: currentBinding,
      stride: uint32(attribute.Size),
      inputRate: rate,
    )
    # Larger types (e.g. Mat44) are split into several attribute descriptions;
    # for most other types there is exactly one. Each one takes its own
    # location slot(s).
    for descriptor in 0 ..< attribute.thetype.NumberOfVertexInputAttributeDescriptors:
      attributes.add VkVertexInputAttributeDescription(
        binding: currentBinding,
        location: nextLocation,
        format: attribute.thetype.GetVkFormat,
        offset: uint32(descriptor * attribute.Size(perDescriptor = true)),
      )
      nextLocation += uint32(attribute.thetype.NLocationSlots)

  VkPipelineVertexInputStateCreateInfo(
    sType: VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO,
    vertexBindingDescriptionCount: uint32(bindings.len),
    pVertexBindingDescriptions: bindings.ToCPointer,
    vertexAttributeDescriptionCount: uint32(attributes.len),
    pVertexAttributeDescriptions: attributes.ToCPointer,
  )
211
212
proc GetPipelineInfo*(shader: ShaderModule): VkPipelineShaderStageCreateInfo =
  ## Describes `shader` as a pipeline shader stage.
  ##
  ## `pName` is a `cstring` view of the entrypoint string held in
  ## `shader.configuration`, so the result must not outlive that string.
  result.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
  result.stage = shader.stage
  result.module = shader.vk
  result.pName = cstring(shader.configuration.entrypoint)
220
proc Destroy*(shader: var ShaderModule) =
  ## Destroys the Vulkan shader module and resets `shader.vk`.
  ## Both the owning device handle and the module handle must be valid.
  assert shader.device.vk.Valid
  assert shader.vk.Valid
  vkDestroyShaderModule(shader.device.vk, shader.vk, nil)
  Reset(shader.vk)