changeset 100:f4c6ff03b12c

add: initial version of better shader-shit
author Sam <sam@basx.dev>
date Fri, 17 Mar 2023 01:11:18 +0700
parents 4deffc94484a
children f0ceb8c17d2c
files src/semicongine/vulkan/pipeline.nim src/semicongine/vulkan/shader.nim src/semicongine/vulkan/vertex.nim
diffstat 3 files changed, 138 insertions(+), 86 deletions(-) [+]
line wrap: on
line diff
--- a/src/semicongine/vulkan/pipeline.nim	Tue Mar 14 13:21:40 2023 +0700
+++ b/src/semicongine/vulkan/pipeline.nim	Fri Mar 17 01:11:18 2023 +0700
@@ -9,9 +9,11 @@
   Pipeline = object
     device: Device
     vk*: VkPipeline
+    layout: VkPipelineLayout
+    descriptorLayout: VkDescriptorSetLayout
 
 
-proc createPipeline*(renderPass: RenderPass, vertexShader: VertexShader, fragmentShader: FragmentShader): Pipeline =
+proc createPipeline*(renderPass: RenderPass, vertexShader: Shader, fragmentShader: Shader): Pipeline =
   assert renderPass.vk.valid
   assert renderPass.device.vk.valid
   result.device = renderPass.device
@@ -33,33 +35,32 @@
     bindingCount: uint32(descriptorLayoutBinding.len),
     pBindings: descriptorLayoutBinding.toCPointer
   )
-  var descriptorLayout: VkDescriptorSetLayout
   checkVkResult vkCreateDescriptorSetLayout(
     renderPass.device.vk,
     addr(layoutCreateInfo),
     nil,
-    addr(descriptorLayout),
+    addr(result.descriptorLayout),
   )
-  var pushConstant = VkPushConstantRange(
-    stageFlags: toBits shaderStage,
-    offset: 0,
-    size: 0,
-  )
-
-  var descriptorSets: seq[VkDescriptorSetLayout] = @[descriptorLayout]
-  var pushConstants: seq[VkPushConstantRange] = @[pushConstant]
+  # var pushConstant = VkPushConstantRange(
+    # stageFlags: toBits shaderStage,
+    # offset: 0,
+    # size: 0,
+  # )
+  var descriptorSets: seq[VkDescriptorSetLayout] = @[result.descriptorLayout]
+  # var pushConstants: seq[VkPushConstantRange] = @[pushConstant]
   var pipelineLayoutInfo = VkPipelineLayoutCreateInfo(
       sType: VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
       setLayoutCount: uint32(descriptorSets.len),
       pSetLayouts: descriptorSets.toCPointer,
-      pushConstantRangeCount: uint32(pushConstants.len),
-      pPushConstantRanges: pushConstants.toCPointer,
+      # pushConstantRangeCount: uint32(pushConstants.len),
+      # pPushConstantRanges: pushConstants.toCPointer,
     )
-  var pipelineLayout: VkPipelineLayout
-  checkVkResult vkCreatePipelineLayout(renderPass.device.vk, addr(pipelineLayoutInfo), nil, addr(pipelineLayout))
+  checkVkResult vkCreatePipelineLayout(renderPass.device.vk, addr(pipelineLayoutInfo), nil, addr(result.layout))
 
   var
-    vertexInputInfo = vertexShader.getVertexBindings()
+    bindings: seq[VkVertexInputBindingDescription]
+    attributes: seq[VkVertexInputAttributeDescription]
+    vertexInputInfo = vertexShader.getVertexInputInfo(bindings, attributes)
     inputAssembly = VkPipelineInputAssemblyStateCreateInfo(
       sType: VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO,
       topology: VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST,
@@ -116,25 +117,25 @@
       dynamicStateCount: uint32(dynamicStates.len),
       pDynamicStates: dynamicStates.toCPointer,
     )
-  var stages = @[vertexShader.getPipelineInfo(), fragmentShader.getPipelineInfo()]
-  var createInfo = VkGraphicsPipelineCreateInfo(
-    sType: VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO,
-    stageCount: uint32(stages.len),
-    pStages: stages.toCPointer,
-    pVertexInputState: addr(vertexInputInfo),
-    pInputAssemblyState: addr(inputAssembly),
-    pViewportState: addr(viewportState),
-    pRasterizationState: addr(rasterizer),
-    pMultisampleState: addr(multisampling),
-    pDepthStencilState: nil,
-    pColorBlendState: addr(colorBlending),
-    pDynamicState: addr(dynamicState),
-    layout: pipelineLayout,
-    renderPass: renderPass.vk,
-    subpass: 0,
-    basePipelineHandle: VkPipeline(0),
-    basePipelineIndex: -1,
-  )
+    stages = @[vertexShader.getPipelineInfo(), fragmentShader.getPipelineInfo()]
+    createInfo = VkGraphicsPipelineCreateInfo(
+      sType: VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO,
+      stageCount: uint32(stages.len),
+      pStages: stages.toCPointer,
+      pVertexInputState: addr(vertexInputInfo),
+      pInputAssemblyState: addr(inputAssembly),
+      pViewportState: addr(viewportState),
+      pRasterizationState: addr(rasterizer),
+      pMultisampleState: addr(multisampling),
+      pDepthStencilState: nil,
+      pColorBlendState: addr(colorBlending),
+      pDynamicState: addr(dynamicState),
+      layout: result.layout,
+      renderPass: renderPass.vk,
+      subpass: 0,
+      basePipelineHandle: VkPipeline(0),
+      basePipelineIndex: -1,
+    )
   checkVkResult vkCreateGraphicsPipelines(
     renderpass.device.vk,
     VkPipelineCache(0),
@@ -147,6 +148,12 @@
 proc destroy*(pipeline: var Pipeline) =
   assert pipeline.device.vk.valid
   assert pipeline.vk.valid
+  assert pipeline.layout.valid
+  assert pipeline.descriptorLayout.valid
 
+  pipeline.device.vk.vkDestroyDescriptorSetLayout(pipeline.descriptorLayout, nil)
+  pipeline.device.vk.vkDestroyPipelineLayout(pipeline.layout, nil)
   pipeline.device.vk.vkDestroyPipeline(pipeline.vk, nil)
+  pipeline.descriptorLayout.reset()
+  pipeline.layout.reset()
   pipeline.vk.reset()
--- a/src/semicongine/vulkan/shader.nim	Tue Mar 14 13:21:40 2023 +0700
+++ b/src/semicongine/vulkan/shader.nim	Fri Mar 17 01:11:18 2023 +0700
@@ -1,19 +1,87 @@
+import std/macros
 import std/os
+import std/enumerate
+import std/logging
 import std/hashes
 import std/strformat
+import std/strutils
 import std/compilesettings
 
 import ./api
 import ./device
 
+let logger = newConsoleLogger()
+addHandler(logger)
+
 type
-  VertexShader*[VertexType] = object
+  Shader*[InputAttributes, Uniforms] = object
     device: Device
-    vertexType*: VertexType
-    module*: VkShaderModule
-  FragmentShader* = object
-    device: Device
-    module*: VkShaderModule
+    vk*: VkShaderModule
+    entrypoint*: string
+    inputs*: InputAttributes
+    uniforms*: Uniforms
+    binary*: seq[uint32]
+
+# produce ast for: static shader string, inputs, uniforms, entrypoint
+
+dumpAstGen:
+  block:
+    const test = 1
+
+macro shader*(inputattributes: typed, uniforms: typed, device: Device, body: untyped): untyped =
+  var shadertype: NimNode
+  var entrypoint: NimNode
+  var version: NimNode
+  var code: NimNode
+  for node in body:
+    if node.kind == nnkCall and node[0].kind == nnkIdent and node[0].strVal == "shadertype":
+      expectKind(node[1], nnkStmtList)
+      expectKind(node[1][0], nnkIdent)
+      shadertype = node[1][0]
+    if node.kind == nnkCall and node[0].kind == nnkIdent and node[0].strVal == "entrypoint":
+      expectKind(node[1], nnkStmtList)
+      expectKind(node[1][0], nnkStrLit)
+      entrypoint = node[1][0]
+    if node.kind == nnkCall and node[0].kind == nnkIdent and node[0].strVal == "version":
+      expectKind(node[1], nnkStmtList)
+      expectKind(node[1][0], nnkIntLit)
+      version = node[1][0]
+    if node.kind == nnkCall and node[0].kind == nnkIdent and node[0].strVal == "code":
+      expectKind(node[1], nnkStmtList)
+      expectKind(node[1][0], {nnkStrLit, nnkTripleStrLit})
+      code = node[1][0]
+  var shadercode: seq[string]
+  shadercode.add &"#version {version.intVal}"
+  shadercode.add &"""void {entrypoint.strVal}(){{
+{code}
+}}"""
+  
+  return nnkBlockStmt.newTree(
+    newEmptyNode(),
+    nnkStmtList.newTree(
+        nnkConstSection.newTree(
+          nnkConstDef.newTree(
+            newIdentNode("shaderbinary"),
+            newEmptyNode(),
+            newCall("compileGLSLToSPIRV", shadertype, newStrLitNode(shadercode.join("\n")), entrypoint)
+          )
+        ),
+        nnkObjConstr.newTree(
+          nnkBracketExpr.newTree(
+            newIdentNode("Shader"),
+            inputattributes,
+            uniforms,
+          ),
+          nnkExprColonExpr.newTree(newIdentNode("device"), device),
+          nnkExprColonExpr.newTree(newIdentNode("entrypoint"), entrypoint),
+          nnkExprColonExpr.newTree(newIdentNode("binary"), newIdentNode("shaderbinary")),
+
+          # vk*: VkShaderModule
+          # inputs*: InputAttributes
+          # uniforms*: Uniforms
+        )
+      )
+    )
 
 proc staticExecChecked(command: string, input = ""): string {.compileTime.} =
   let (output, exitcode) = gorgeEx(
@@ -33,7 +101,7 @@
   of VK_SHADER_STAGE_COMPUTE_BIT: "comp"
   else: ""
 
-proc compileGLSLToSPIRV(stage: static VkShaderStageFlagBits, shaderSource: static string, entrypoint: string): seq[uint32] {.compileTime.} =
+proc compileGLSLToSPIRV*(stage: static VkShaderStageFlagBits, shaderSource: static string, entrypoint: static string): seq[uint32] {.compileTime.} =
   when defined(nimcheck): # will not run if nimcheck is running
     return result
   const
@@ -43,6 +111,10 @@
     shaderfile = getTempDir() / &"shader_{shaderHash}.{stagename}"
     projectPath = querySetting(projectPath)
 
+  echo "shader of type ", stage
+  for i, line in enumerate(shaderSource.splitlines()):
+    echo "  ", i + 1, " ", line
+
   discard staticExecChecked(
       command = &"{projectPath}/glslangValidator --entry-point {entrypoint} -V --stdin -S {stagename} -o {shaderfile}",
       input = shaderSource
@@ -52,12 +124,6 @@
     let shaderbinary = staticRead shaderfile.replace("\\", "/")
   else:
     let shaderbinary = staticRead shaderfile
-  when defined(linux):
-    discard staticExecChecked(command = fmt"rm {shaderfile}")
-  elif defined(windows):
-    discard staticExecChecked(command = fmt"cmd.exe /c del {shaderfile}")
-  else:
-    raise newException(Exception, "Unsupported operating system")
 
   var i = 0
   while i < shaderbinary.len:
@@ -69,46 +135,24 @@
     )
     i += 4
 
-proc createVertexShader*[VertexType](device: Device, shader: static string, vertexType: VertexType, entryPoint: static string = "main"): VertexShader[VertexType] =
-  assert device.vk.valid
-
-  const constcode = compileGLSLToSPIRV(VK_SHADER_STAGE_VERTEX_BIT, shader, entryPoint)
-  var code = constcode
+proc loadShaderCode*(device: Device, binary: var seq[uint32]): VkShaderModule =
   var createInfo = VkShaderModuleCreateInfo(
     sType: VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
-    codeSize: uint(code.len * sizeof(uint32)),
-    pCode: if code.len > 0: addr(code[0]) else: nil,
+    codeSize: uint(binary.len * sizeof(uint32)),
+    pCode: if binary.len > 0: addr(binary[0]) else: nil,
   )
-  checkVkResult vkCreateShaderModule(device.vk, addr(createInfo), nil, addr(result.module))
-
-proc createFragmentShader*(device: Device, shader: static string, entryPoint: static string = "main"): FragmentShader =
-  assert device.vk.valid
+  checkVkResult vkCreateShaderModule(device.vk, addr(createInfo), nil, addr(result))
 
-  const constcode = compileGLSLToSPIRV(VK_SHADER_STAGE_FRAGMENT_BIT, shader, entryPoint)
-  var code = constcode
-  var createInfo = VkShaderModuleCreateInfo(
-    sType: VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
-    codeSize: uint(code.len * sizeof(uint32)),
-    pCode: if code.len > 0: addr(code[0]) else: nil,
-  )
-  checkVkResult vkCreateShaderModule(device.vk, addr(createInfo), nil, addr(result.module))
-
-proc getPipelineInfo*(shader: VertexShader|FragmentShader, entryPoint = "main"): VkPipelineShaderStageCreateInfo =
+proc getPipelineInfo*(shader: Shader): VkPipelineShaderStageCreateInfo =
   VkPipelineShaderStageCreateInfo(
     sType: VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
     stage: VK_SHADER_STAGE_VERTEX_BIT,
-    module: shader.module,
-    pName: cstring(entryPoint),
+    module: shader.vk,
+    pName: cstring(shader.entrypoint),
   )
 
-proc destroy*(shader: var VertexShader) =
+proc destroy*(shader: var Shader) =
   assert shader.device.vk.valid
-  assert shader.module.valid
-  shader.device.vk.vkDestroyShaderModule(shader.module, nil)
-  shader.module.reset
-
-proc destroy*(shader: var FragmentShader) =
-  assert shader.device.vk.valid
-  assert shader.module.valid
-  shader.device.vk.vkDestroyShaderModule(shader.module, nil)
-  shader.module.reset
+  assert shader.vk.valid
+  shader.device.vk.vkDestroyShaderModule(shader.vk, nil)
+  shader.vk.reset
--- a/src/semicongine/vulkan/vertex.nim	Tue Mar 14 13:21:40 2023 +0700
+++ b/src/semicongine/vulkan/vertex.nim	Fri Mar 17 01:11:18 2023 +0700
@@ -80,14 +80,15 @@
   elif T is TVec4[float64]: VK_FORMAT_R64G64B64A64_SFLOAT
   else: {.error: "Unsupported vertex attribute type".}
 
-proc getVertexBindings*(shader: VertexShader): VkPipelineVertexInputStateCreateInfo =
+proc getVertexInputInfo*(
+  shader: Shader,
+  bindings: var seq[VkVertexInputBindingDescription],
+  attributes: var seq[VkVertexInputAttributeDescription],
+): VkPipelineVertexInputStateCreateInfo =
   var location = 0'u32
   var binding = 0'u32
-  var offset = 0'u32
-  var bindings: seq[VkVertexInputBindingDescription]
-  var attributes: seq[VkVertexInputAttributeDescription]
 
-  for name, value in shader.vertexType.fieldPairs:
+  for name, value in shader.inputs.fieldPairs:
     bindings.add VkVertexInputBindingDescription(
       binding: binding,
       stride: uint32(sizeof(value)),