diff src/vulkan_api/vulkan_api_generator.nim @ 84:8412f433dc46

fix: tons of errors in wrapper generator, can now compile, extension function not loaded yet it seems
author Sam <sam@basx.dev>
date Fri, 24 Feb 2023 01:32:45 +0700
parents 5e19aead2b61
children e872cf354110
line wrap: on
line diff
--- a/src/vulkan_api/vulkan_api_generator.nim	Thu Feb 23 00:34:38 2023 +0700
+++ b/src/vulkan_api/vulkan_api_generator.nim	Fri Feb 24 01:32:45 2023 +0700
@@ -83,6 +83,48 @@
 func tableSorted(table: Table[int, string]): seq[(int, string)] =
   result = toSeq(table.pairs)
   result.sort((a, b) => cmp(a[0], b[0]))
+func findType(declNode: XmlNode): string =
+  # examples:
+  # char** -> cstringArray
+  # void* -> pointer
+  # char* -> cstring
+  #
+  # int* -> ptr int
+  # void** -> ptr pointer
+  # int** -> ptr ptr int
+  var basetype = ""
+  var apointer = ""
+  var arraylen = ""
+  for child in declNode:
+    if child.kind == xnText:
+        if "[" in child.text:
+          if "[" in child.text and "]" in child.text:
+            arraylen = child.text.strip(chars={'[', ']'}).replace("][", "*")
+          else:
+            arraylen = declNode.child("enum")[0].text
+        else:
+          for i in 0 ..< child.text.count('*'):
+            apointer = apointer & "ptr "
+    elif child.tag == "type":
+      basetype = mapType(child[0].text)
+  if basetype == "void":
+    if apointer.count("ptr ") > 0:
+      basetype = "pointer"
+      apointer = apointer[0 ..< ^4]
+  elif basetype == "char":
+    if apointer.count("ptr ") == 1:
+      basetype = "cstring"
+      apointer = ""
+    elif apointer.count("ptr ") == 2:
+      basetype = "cstringArray"
+      apointer = ""
+    elif apointer.count("ptr ") > 2:
+      basetype = "cstringArray"
+      apointer = apointer[0 ..< ^8]
+
+  result = &"{apointer}{basetype}"
+  if arraylen != "":
+    result = &"array[{arraylen}, {result}]"
 
 # serializers
 func serializeEnum(node: XmlNode, api: XmlNode): seq[string] =
@@ -174,14 +216,18 @@
 
   # generate bitsets (normal enums in the C API, but bitfield-enums in Nim)
   elif node.attr("type") == "bitmask":
+    var predefined_enum_sets: seq[string]
     for value in node.findAll("enum"):
       if value.hasAttr("bitpos"):
         values[smartParseInt(value.attr("bitpos"))] = value.attr("name")
       elif node.attr("name") == "VkVideoEncodeRateControlModeFlagBitsKHR": # special exception, for some reason this has values instead of bitpos
         values[smartParseInt(value.attr("value"))] = value.attr("name")
+      elif value.hasAttr("value"): # create a const that has multiple bits set
+        predefined_enum_sets.add &"  {value.attr(\"name\")}* = {value.attr(\"value\")}"
+
     if values.len > 0:
       if node.hasAttr("bitwidth"):
-        result.add "  " & name & "* {.size: " & $(smartParseInt(node.attr("bitwidth")) div 8) & ".} = enum"
+        result.add "  " & name & "* {.size: 8.} = enum"
       else:
         result.add "  " & name & "* {.size: sizeof(cint).} = enum"
       for (bitpos, enumvalue) in tableSorted(values):
@@ -192,19 +238,30 @@
         let enumEntry = &"    {enumvalue} = 0b{value}"
         if not (enumEntry in result): # the specs define duplicate entries for backwards compat
           result.add enumEntry
-    let cApiName = name.replace("FlagBits", "Flags")
-    if node.hasAttr("bitwidth"): # assumes this is always 64
-      if values.len > 0:
-        result.add &"""converter BitsetToNumber*(flags: openArray[{name}]): {cApiName} =
-  for flag in flags:
-    result = {cApiName}(uint64(result) or uint(flag))"""
-        result.add "type"
-    else:
-      if values.len > 0:
-        result.add &"""converter BitsetToNumber*(flags: openArray[{name}]): {cApiName} =
-  for flag in flags:
-    result = {cApiName}(uint(result) or uint(flag))"""
-        result.add "type"
+      let cApiName = name.replace("FlagBits", "Flags")
+      if node.hasAttr("bitwidth"): # assuming this attribute is always 64
+        if values.len > 0:
+          result.add &"""converter BitsetToNumber*(flags: openArray[{name}]): {cApiName} =
+    for flag in flags:
+      result = {cApiName}(int64(result) or int64(flag))"""
+          result.add &"""converter NumberToBitset*(number: {cApiName}): seq[{name}] =
+        for value in {name}.items:
+          if (value.ord and int64(number)) > 0:
+            result.add value"""
+      else:
+        if values.len > 0:
+          result.add &"""func toBits*(flags: openArray[{name}]): {cApiName} =
+    for flag in flags:
+      result = {cApiName}(uint(result) or uint(flag))"""
+          result.add &"""func toEnums*(number: {cApiName}): seq[{name}] =
+    for value in {name}.items:
+      if (value.ord and cint(number)) > 0:
+        result.add value"""
+      if predefined_enum_sets.len > 0:
+        result.add "const"
+        result.add predefined_enum_sets
+      result.add "type"
+
 
 func serializeStruct(node: XmlNode): seq[string] =
   let name = node.attr("name")
@@ -215,24 +272,7 @@
   for member in node.findAll("member"):
     if not member.hasAttr("api") or member.attr("api") == "vulkan":
       let fieldname = member.child("name")[0].text.strip(chars={'_'})
-      var fieldtype = member.child("type")[0].text.strip(chars={'_'})
-      # detect pointers
-      for child in member:
-        if child.kind == xnText and child.text.strip() == "*":
-          fieldtype = &"ptr {mapType(fieldtype)}"
-        elif child.kind == xnText and child.text.strip() == "* const*":
-          fieldtype = "cstringArray"
-      fieldtype = mapType(fieldtype)
-      # detect arrays
-      for child in member:
-        if child.kind == xnText and child.text.endsWith("]"):
-          var thelen = ""
-          if "[" in child.text:
-            thelen = child.text.strip(chars={'[', ']'}).replace("][", "*")
-          else:
-            thelen = member.child("enum")[0].text
-          fieldtype = &"array[{thelen}, {fieldtype}]"
-      result.add &"    {mapName(fieldname)}*: {fieldtype}"
+      result.add &"    {mapName(fieldname)}*: {findType(member)}"
 
 func serializeFunctiontypes(api: XmlNode): seq[string] =
   for node in api.findAll("type"):
@@ -330,11 +370,7 @@
   for param in node:
     if param.tag == "param" and param.attr("api") in ["", "vulkan"]:
       let fieldname = param.child("name")[0].text.strip(chars={'_'})
-      var fieldtype = param.child("type")[0].text.strip(chars={'_'})
-      if param[param.len - 2].kind == xnText and param[param.len - 2].text.strip() == "*":
-        fieldtype = &"ptr {mapType(fieldtype)}"
-      fieldtype = mapType(fieldtype)
-      params.add &"{mapName(fieldname)}: {fieldtype}"
+      params.add &"{mapName(fieldname)}: {findType(param)}"
   let allparams = params.join(", ")
   return (name, &"proc({allparams}): {resulttype} {{.stdcall.}}")
 
@@ -378,6 +414,11 @@
     "basetypes": @[
       "import std/dynlib",
       "import std/tables",
+      "import std/strutils",
+      "import std/logging",
+      "import std/macros",
+      "import std/private/digitsutils",
+      "from typetraits import HoleyEnum",
       "type",
       "  VkHandle* = distinct uint",
       "  VkNonDispatchableHandle* = distinct uint",
@@ -404,6 +445,13 @@
       error "Vulkan error: ", astToStr(call), " returned ", $value
       raise newException(Exception, "Vulkan error: " & astToStr(call) &
           " returned " & $value)""",
+    """
+# custom enum iteration (for enum values > 2^16)
+macro enumFullRange(a: typed): untyped =
+  newNimNode(nnkBracket).add(a.getType[1][1..^1])
+
+iterator items[T: HoleyEnum](E: typedesc[T]): T =
+  for a in enumFullRange(E): yield a""",
     ],
     "structs": @["type"],
     "enums": @["type"],
@@ -583,11 +631,13 @@
   mainout.add "  let instance = VkInstance(0)"
   for l in GLOBAL_COMMANDS:
     mainout.add procLoads[l]
-  writeFile outdir / &"api.nim", mainout.join("\n")
-
   mainout.add ""
   mainout.add "converter VkBool2NimBool*(a: VkBool32): bool = a > 0"
   mainout.add "converter NimBool2VkBool*(a: bool): VkBool32 = VkBool32(a)"
+  mainout.add "proc `$`*(x: uint32): string {.raises: [].} = addInt(result, x)"
+
+  writeFile outdir / &"api.nim", mainout.join("\n")
+
 
   for filename, filecontent in outputFiles.pairs:
     if filename.startsWith("platform/"):