From 2aa7f5f9b3c84329a3457eca8bd8f66cf85d4331 Mon Sep 17 00:00:00 2001 From: Liam Diprose Date: Fri, 12 Sep 2025 12:18:47 +1200 Subject: [PATCH] Rust GPU --- Cargo.lock | 103 +++++++ Cargo.toml | 8 + codegen/src/emit.rs | 7 +- codegen/src/main.rs | 1 + codegen/src/rust_gpu.rs | 582 ++++++++++++++++++++++++++++++++++++++++ flake.lock | 101 +++++++ flake.nix | 61 +++++ src/lib.rs | 193 +------------ 8 files changed, 868 insertions(+), 188 deletions(-) create mode 100644 codegen/src/rust_gpu.rs create mode 100644 flake.lock create mode 100644 flake.nix diff --git a/Cargo.lock b/Cargo.lock index 0ddb86d..00d3382 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,109 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "geometric_algebra" version = "0.3.0" +dependencies = [ + "spirv-std", +] + +[[package]] +name = "glam" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5418c17512bdf42730f9032c74e1ae39afc408745ebb2acf72fbc4691c17945" +dependencies = [ + "libm", +] + +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + +[[package]] +name = "num-traits" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "proc-macro2" +version = "1.0.67" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "spirv-std" +version = "0.9.0" +source = "git+https://github.com/EmbarkStudios/rust-gpu?rev=d0e374968a37d8a37c4f3509cd10719d384470f6#d0e374968a37d8a37c4f3509cd10719d384470f6" +dependencies = [ + "bitflags", + "glam", + "num-traits", + "spirv-std-macros", + "spirv-std-types", +] + +[[package]] +name = "spirv-std-macros" +version = "0.9.0" +source = "git+https://github.com/EmbarkStudios/rust-gpu?rev=d0e374968a37d8a37c4f3509cd10719d384470f6#d0e374968a37d8a37c4f3509cd10719d384470f6" +dependencies = [ + "proc-macro2", + "quote", + "spirv-std-types", + "syn", +] + +[[package]] +name = "spirv-std-types" +version = "0.9.0" +source = "git+https://github.com/EmbarkStudios/rust-gpu?rev=d0e374968a37d8a37c4f3509cd10719d384470f6#d0e374968a37d8a37c4f3509cd10719d384470f6" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" diff --git a/Cargo.toml b/Cargo.toml index 55a1f89..5c0459b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,11 @@ repository = "https://github.com/Lichtso/geometric_algebra/" keywords = ["math", "simd", "vector", "geometric-algebra", "geometry"] license = "MIT" edition = "2018" + +[dependencies.spirv-std] +git = "https://github.com/EmbarkStudios/rust-gpu" +rev = "d0e374968a37d8a37c4f3509cd10719d384470f6" + +# [dependencies] +# spirv-std = { path = "/home/liam/projects/abel/crates/spirv-std", version = "=0.9.0" } +# spirv-std = "0.9.0" diff --git a/codegen/src/emit.rs b/codegen/src/emit.rs index 84c79cd..3121c0e 100644 --- a/codegen/src/emit.rs +++ b/codegen/src/emit.rs @@ -1,4 +1,7 @@ -use crate::{algebra::BasisElement, ast::AstNode, glsl, rust}; +use crate::{algebra::BasisElement, ast::AstNode, glsl, + rust_gpu, + // rust +}; pub fn camel_to_snake_case(collector: &mut W, name: &str) -> std::io::Result<()> { let mut underscores = name.chars().enumerate().filter(|(_i, c)| c.is_uppercase()).map(|(i, _c)| i).peekable(); @@ -55,7 +58,7 @@ impl Emitter { impl Emitter { pub fn emit(&mut self, ast_node: &AstNode) -> std::io::Result<()> { - rust::emit_code(&mut self.rust_collector, ast_node, 0)?; + rust_gpu::emit_code(&mut self.rust_collector, ast_node, 0)?; glsl::emit_code(&mut self.glsl_collector, ast_node, 0)?; Ok(()) } diff --git a/codegen/src/main.rs b/codegen/src/main.rs index 9ec7c1e..bdd0cce 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -4,6 +4,7 @@ mod compile; mod emit; mod glsl; mod rust; +mod rust_gpu; use crate::{ algebra::{BasisElement, GeometricAlgebra, Involution, MultiVectorClass, MultiVectorClassRegistry, Product}, diff --git a/codegen/src/rust_gpu.rs b/codegen/src/rust_gpu.rs new file mode 100644 index 0000000..d0c4f75 --- /dev/null +++ b/codegen/src/rust_gpu.rs @@ -0,0 +1,582 @@ +use crate::{ + ast::{AstNode, DataType, Expression, ExpressionContent, Parameter}, + emit::{camel_to_snake_case, emit_element_name, emit_indentation}, +}; + +fn emit_data_type(collector: &mut W, data_type: &DataType) -> std::io::Result<()> { + match data_type { + DataType::Integer => collector.write_all(b"isize"), + DataType::SimdVector(size) if *size == 1 => collector.write_all(b"f32"), + DataType::SimdVector(size) => collector.write_fmt(format_args!("Vec{}", *size)), + DataType::MultiVector(class) if class.is_scalar() => collector.write_all(b"f32"), + DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)), + } +} + +fn emit_expression(collector: &mut W, expression: &Expression) -> std::io::Result<()> { + match &expression.content { + ExpressionContent::None => unreachable!(), + ExpressionContent::Variable(_data_type, name) => { + collector.write_all(name.bytes().collect::>().as_slice())?; + } + ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, method_name, _, arguments) => { + emit_expression(collector, inner_expression)?; + collector.write_all(b".")?; + camel_to_snake_case(collector, method_name)?; + collector.write_all(b"(")?; + for (i, (_argument_class, argument)) in arguments.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_expression(collector, argument)?; + } + collector.write_all(b")")?; + } + ExpressionContent::InvokeClassMethod(class, "Constructor", arguments) if class.is_scalar() => { + emit_expression(collector, &arguments[0].1)?; + } + ExpressionContent::InvokeClassMethod(class, "Constructor", arguments) => { + collector.write_fmt(format_args!("{} {{ groups: {}Groups {{ ", class.class_name, class.class_name))?; + for (i, (_argument_class, argument)) in arguments.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("g{}: ", i))?; + emit_expression(collector, argument)?; + } + collector.write_all(b" } }")?; + } + ExpressionContent::InvokeClassMethod(class, method_name, arguments) => { + emit_data_type(collector, &DataType::MultiVector(class))?; + collector.write_all(b"::")?; + camel_to_snake_case(collector, method_name)?; + collector.write_all(b"(")?; + for (i, (_argument_class, argument)) in arguments.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_expression(collector, argument)?; + } + collector.write_all(b")")?; + } + ExpressionContent::Conversion(_source_class, _destination_class, inner_expression) => { + emit_expression(collector, inner_expression)?; + collector.write_all(b".into()")?; + } + ExpressionContent::Select(condition_expression, then_expression, else_expression) => { + collector.write_all(b"if ")?; + emit_expression(collector, condition_expression)?; + collector.write_all(b" { ")?; + emit_expression(collector, then_expression)?; + collector.write_all(b" } else { ")?; + emit_expression(collector, else_expression)?; + collector.write_all(b" }")?; + } + ExpressionContent::Access(inner_expression, array_index) => { + emit_expression(collector, inner_expression)?; + if !inner_expression.is_scalar() { + collector.write_fmt(format_args!(".group{}()", array_index))?; + } + } + ExpressionContent::Swizzle(inner_expression, indices) => { + if expression.size == 1 { + emit_expression(collector, inner_expression)?; + if inner_expression.size > 1 { + collector.write_fmt(format_args!("[{}]", indices[0]))?; + } + } else { + // Swizzle + emit_expression(collector, inner_expression)?; + collector.write_all(b".")?; + collector.write_all( + indices.iter() + .map(|component_index| { + match component_index { + 0 => 'x', + 1 => 'y', + 2 => 'z', + 3 => 'w', + _ => unimplemented!() + } + }) + .collect::() + .as_bytes() + )?; + + collector.write_all(b"()")?; + } + } + ExpressionContent::Gather(inner_expression, indices) => { + if expression.size == 1 && inner_expression.is_scalar() { + emit_expression(collector, inner_expression)?; + } else { + if expression.size > 1 { + emit_data_type(collector, &DataType::SimdVector(expression.size))?; + + if indices.len() > 1 { + collector.write_all(b"::from(")?; + } else { + collector.write_all(b"::splat(")?; + } + } + if indices.len() > 1 { + collector.write_all(b"[")?; + } + for (i, (array_index, component_index)) in indices.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_expression(collector, inner_expression)?; + if !inner_expression.is_scalar() { + collector.write_fmt(format_args!(".group{}()", array_index))?; + if inner_expression.size > 1 { + collector.write_fmt(format_args!("[{}]", *component_index))?; + } + } + } + if indices.len() > 1 { + collector.write_all(b"]")?; + } + if expression.size > 1 { + collector.write_all(b")")?; + } + } + } + ExpressionContent::Constant(data_type, values) => match data_type { + DataType::Integer => collector.write_fmt(format_args!("{}", values[0] as f32))?, + DataType::SimdVector(_size) => { + if expression.size == 1 { + collector.write_fmt(format_args!("{:.1}", values[0] as f32))?; + } else { + emit_data_type(collector, &DataType::SimdVector(expression.size))?; + if values.len() > 1 { + collector.write_all(b"::from(")?; + collector.write_all(b"[")?; + } else { + collector.write_all(b"::splat(")?; + } + for (i, value) in values.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("{:.1}", *value as f32))?; + } + if values.len() > 1 { + collector.write_all(b"]")?; + } + collector.write_all(b")")?; + } + } + _ => unreachable!(), + }, + ExpressionContent::SquareRoot(inner_expression) => { + emit_expression(collector, inner_expression)?; + collector.write_all(b".sqrt()")?; + } + ExpressionContent::Add(lhs, rhs) + | ExpressionContent::Subtract(lhs, rhs) + | ExpressionContent::Multiply(lhs, rhs) + | ExpressionContent::Divide(lhs, rhs) + | ExpressionContent::LessThan(lhs, rhs) + | ExpressionContent::Equal(lhs, rhs) + | ExpressionContent::LogicAnd(lhs, rhs) + | ExpressionContent::BitShiftRight(lhs, rhs) => { + emit_expression(collector, lhs)?; + collector.write_all(match expression.content { + ExpressionContent::Add(_, _) => b" + ", + ExpressionContent::Subtract(_, _) => b" - ", + ExpressionContent::Multiply(_, _) => b" * ", + ExpressionContent::Divide(_, _) => b" / ", + ExpressionContent::LessThan(_, _) => b" < ", + ExpressionContent::Equal(_, _) => b" == ", + ExpressionContent::LogicAnd(_, _) => b" & ", + ExpressionContent::BitShiftRight(_, _) => b" >> ", + _ => unreachable!(), + })?; + emit_expression(collector, rhs)?; + } + } + Ok(()) +} + +fn emit_assign_trait(collector: &mut W, result: &Parameter, parameters: &[Parameter]) -> std::io::Result<()> { + if result.multi_vector_class() != parameters[0].multi_vector_class() { + return Ok(()); + } + collector.write_fmt(format_args!("impl {}Assign<", result.name))?; + emit_data_type(collector, ¶meters[1].data_type)?; + collector.write_all(b"> for ")?; + emit_data_type(collector, ¶meters[0].data_type)?; + collector.write_all(b" {\n fn ")?; + camel_to_snake_case(collector, result.name)?; + collector.write_all(b"_assign(&mut self, other: ")?; + emit_data_type(collector, ¶meters[1].data_type)?; + collector.write_all(b") {\n *self = (*self).")?; + camel_to_snake_case(collector, result.name)?; + collector.write_all(b"(other);\n }\n}\n\n") +} + +pub fn emit_code(collector: &mut W, ast_node: &AstNode, indentation: usize) -> std::io::Result<()> { + match &ast_node { + AstNode::None => {} + AstNode::Preamble => { + collector.write_all(b"#![allow(clippy::assign_op_pattern)]\n")?; + collector.write_all(b"use spirv_std::glam::{Vec2, Vec3, Vec4, Vec2Swizzles, Vec3Swizzles, Vec4Swizzles};\n")?; + collector.write_all(b"use spirv_std::num_traits::Float;\n")?; + collector.write_all(b"use crate::*;\nuse core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};\n\n")?; + } + AstNode::ClassDefinition { class } => { + if class.is_scalar() { + return Ok(()); + } + let element_count = class.grouped_basis.iter().fold(0, |a, b| a + b.len()); + let mut simd_widths = Vec::new(); + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!("#[derive(Clone, Copy)]\nstruct {}Groups {{\n", class.class_name))?; + for (j, group) in class.grouped_basis.iter().enumerate() { + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"/// ")?; + for (i, element) in group.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("{}", element))?; + } + collector.write_all(b"\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("g{}: ", j))?; + emit_data_type(collector, &DataType::SimdVector(group.len()))?; + collector.write_all(b",\n")?; + simd_widths.push(if group.len() == 1 { 1 } else { 4 }); + } + collector.write_all(b"}\n\n")?; + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!("#[derive(Clone, Copy)]\npub union {} {{\n", class.class_name))?; + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("groups: {}Groups,\n", class.class_name))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"/// ")?; + for (j, group) in class.grouped_basis.iter().enumerate() { + for (i, element) in group.iter().enumerate() { + if j > 0 || i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("{}", element))?; + } + for _ in group.len()..simd_widths[j] { + collector.write_all(b", 0")?; + } + } + collector.write_all(b"\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("elements: [f32; {}],\n", simd_widths.iter().fold(0, |a, b| a + b)))?; + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n\n")?; + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!("impl {} {{\n", class.class_name))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"#[allow(clippy::too_many_arguments)]\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"pub const fn new(")?; + let mut element_index = 0; + for group in class.grouped_basis.iter() { + for element in group.iter() { + if element_index > 0 { + collector.write_all(b", ")?; + } + emit_element_name(collector, element)?; + collector.write_all(b": f32")?; + element_index += 1; + } + } + collector.write_all(b") -> Self {\n")?; + emit_indentation(collector, indentation + 2)?; + collector.write_all(b"Self { elements: [")?; + element_index = 0; + for (j, group) in class.grouped_basis.iter().enumerate() { + for element in group.iter() { + if element_index > 0 { + collector.write_all(b", ")?; + } + emit_element_name(collector, element)?; + element_index += 1; + } + for _ in group.len()..simd_widths[j] { + collector.write_all(b", 0.0")?; + } + } + collector.write_all(b"] }\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"pub const fn from_groups(")?; + for (j, group) in class.grouped_basis.iter().enumerate() { + if j > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("g{}: ", j))?; + emit_data_type(collector, &DataType::SimdVector(group.len()))?; + } + collector.write_all(b") -> Self {\n")?; + emit_indentation(collector, indentation + 2)?; + collector.write_fmt(format_args!("Self {{ groups: {}Groups {{ ", class.class_name))?; + for j in 0..class.grouped_basis.len() { + if j > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("g{}", j))?; + } + collector.write_all(b" } }\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + for (j, group) in class.grouped_basis.iter().enumerate() { + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"#[inline(always)]\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("pub fn group{}(&self) -> ", j))?; + emit_data_type(collector, &DataType::SimdVector(group.len()))?; + collector.write_all(b" {\n")?; + emit_indentation(collector, indentation + 2)?; + collector.write_fmt(format_args!("unsafe {{ self.groups.g{} }}\n", j))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"#[inline(always)]\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("pub fn group{}_mut(&mut self) -> &mut ", j))?; + emit_data_type(collector, &DataType::SimdVector(group.len()))?; + collector.write_all(b" {\n")?; + emit_indentation(collector, indentation + 2)?; + collector.write_fmt(format_args!("unsafe {{ &mut self.groups.g{} }}\n", j))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + } + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n\n")?; + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!( + "const {}_INDEX_REMAP: [usize; {}] = [", + class.class_name.to_uppercase(), + element_count + ))?; + let mut element_index = 0; + let mut index_remap = Vec::new(); + for (j, group) in class.grouped_basis.iter().enumerate() { + for _ in 0..group.len() { + if element_index > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("{}", element_index))?; + index_remap.push(element_index); + element_index += 1; + } + element_index += simd_widths[j].saturating_sub(group.len()); + } + collector.write_all(b"];\n\n")?; + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!("impl core::ops::Index for {} {{\n", class.class_name))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"type Output = f32;\n\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"fn index(&self, index: usize) -> &Self::Output {\n")?; + emit_indentation(collector, indentation + 2)?; + collector.write_fmt(format_args!( + "unsafe {{ &self.elements[{}_INDEX_REMAP[index]] }}\n", + class.class_name.to_uppercase() + ))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n\n")?; + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!("impl core::ops::IndexMut for {} {{\n", class.class_name))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"fn index_mut(&mut self, index: usize) -> &mut Self::Output {\n")?; + emit_indentation(collector, indentation + 2)?; + collector.write_fmt(format_args!( + "unsafe {{ &mut self.elements[{}_INDEX_REMAP[index]] }}\n", + class.class_name.to_uppercase() + ))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n\n")?; + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!( + "impl core::convert::From<{}> for [f32; {}] {{\n", + class.class_name, element_count + ))?; + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("fn from(vector: {}) -> Self {{\n", class.class_name))?; + emit_indentation(collector, indentation + 2)?; + collector.write_all(b"unsafe { [")?; + for (i, remapped) in index_remap.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("vector.elements[{}]", remapped))?; + } + collector.write_all(b"] }\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n\n")?; + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!( + "impl core::convert::From<[f32; {}]> for {} {{\n", + element_count, class.class_name + ))?; + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("fn from(array: [f32; {}]) -> Self {{\n", element_count))?; + emit_indentation(collector, indentation + 2)?; + collector.write_all(b"Self { elements: [")?; + let mut element_index = 0; + for (j, group) in class.grouped_basis.iter().enumerate() { + for _ in 0..group.len() { + if element_index > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("array[{}]", element_index))?; + element_index += 1; + } + for _ in group.len()..simd_widths[j] { + collector.write_all(b", 0.0")?; + } + } + collector.write_all(b"] }\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n\n")?; + emit_indentation(collector, indentation)?; + collector.write_fmt(format_args!("impl core::fmt::Debug for {} {{\n", class.class_name))?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"fn fmt(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {\n")?; + emit_indentation(collector, indentation + 2)?; + collector.write_all(b"formatter\n")?; + emit_indentation(collector, indentation + 3)?; + collector.write_fmt(format_args!(".debug_struct(\"{}\")\n", class.class_name))?; + let mut element_index = 0; + for group in class.grouped_basis.iter() { + for element in group.iter() { + emit_indentation(collector, indentation + 3)?; + collector.write_fmt(format_args!(".field(\"{}\", &self[{}])\n", element, element_index))?; + element_index += 1; + } + } + emit_indentation(collector, indentation + 3)?; + collector.write_all(b".finish()\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n")?; + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n\n")?; + } + AstNode::ReturnStatement { expression } => { + collector.write_all(b"return ")?; + emit_expression(collector, expression)?; + collector.write_all(b";\n")?; + } + AstNode::VariableAssignment { name, data_type, expression } => { + if let Some(data_type) = data_type { + collector.write_fmt(format_args!("let mut {}", name))?; + collector.write_all(b": ")?; + emit_data_type(collector, data_type)?; + } else { + collector.write_fmt(format_args!("{}", name))?; + } + collector.write_all(b" = ")?; + emit_expression(collector, expression)?; + collector.write_all(b";\n")?; + } + AstNode::IfThenBlock { condition, body } | AstNode::WhileLoopBlock { condition, body } => { + collector.write_all(match &ast_node { + AstNode::IfThenBlock { .. } => b"if ", + AstNode::WhileLoopBlock { .. } => b"while ", + _ => unreachable!(), + })?; + emit_expression(collector, condition)?; + collector.write_all(b" {\n")?; + for statement in body.iter() { + emit_indentation(collector, indentation + 1)?; + emit_code(collector, statement, indentation + 1)?; + } + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n")?; + } + AstNode::TraitImplementation { result, parameters, body } => { + if result.data_type.is_scalar() + && !parameters + .iter() + .any(|parameter| matches!(parameter.data_type, DataType::MultiVector(class) if !class.is_scalar())) + { + return Ok(()); + } + collector.write_fmt(format_args!("impl {}", result.name))?; + let impl_for = match parameters.len() { + 0 => &result.data_type, + 1 if result.name == "Into" => { + collector.write_all(b"<")?; + emit_data_type(collector, &result.data_type)?; + collector.write_all(b">")?; + ¶meters[0].data_type + } + 1 => ¶meters[0].data_type, + 2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => ¶meters[0].data_type, + 2 => { + collector.write_all(b"<")?; + emit_data_type(collector, ¶meters[1].data_type)?; + collector.write_all(b">")?; + ¶meters[0].data_type + } + _ => unreachable!(), + }; + collector.write_all(b" for ")?; + emit_data_type(collector, impl_for)?; + collector.write_all(b" {\n")?; + if !parameters.is_empty() && result.name != "Into" { + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"type Output = ")?; + emit_data_type(collector, &result.data_type)?; + collector.write_all(b";\n\n")?; + } + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"fn ")?; + camel_to_snake_case(collector, result.name)?; + match parameters.len() { + 0 => collector.write_all(b"() -> Self")?, + 1 => { + collector.write_fmt(format_args!("({}) -> ", parameters[0].name))?; + emit_data_type(collector, &result.data_type)?; + } + 2 => { + collector.write_fmt(format_args!("({}, {}: ", parameters[0].name, parameters[1].name))?; + emit_data_type(collector, ¶meters[1].data_type)?; + collector.write_all(b") -> ")?; + emit_data_type(collector, &result.data_type)?; + } + _ => unreachable!(), + } + collector.write_all(b" {\n")?; + for (i, statement) in body.iter().enumerate() { + emit_indentation(collector, indentation + 2)?; + if i + 1 == body.len() { + if let AstNode::ReturnStatement { expression } = statement { + emit_expression(collector, expression)?; + collector.write_all(b"\n")?; + break; + } + } + emit_code(collector, statement, indentation + 2)?; + } + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n}\n\n")?; + match result.name { + "Add" | "Sub" | "Mul" | "Div" => { + emit_assign_trait(collector, result, parameters)?; + } + _ => {} + } + } + } + Ok(()) +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..39ee006 --- /dev/null +++ b/flake.lock @@ -0,0 +1,101 @@ +{ + "nodes": { + "fenix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ], + "rust-analyzer-src": "rust-analyzer-src" + }, + "locked": { + "lastModified": 1757572889, + "narHash": "sha256-tPmVhFoet1ijKduGpmh1uHw5eayjcaHvu3GodIVF9ac=", + "owner": "nix-community", + "repo": "fenix", + "rev": "4e37fff6b22c58717577a2006c3907c9eef452dc", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "fenix", + "rev": "4e37fff6b22c58717577a2006c3907c9eef452dc", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1757632301, + "narHash": "sha256-5q2Cw5/3EB86h9vm/GKhHQWNPddfpZsS7CXS3zTUusU=", + "owner": "NixOs", + "repo": "nixpkgs", + "rev": "639f5e5962c6935a7255d0ca9583fc44cfe95307", + "type": "github" + }, + "original": { + "owner": "NixOs", + "ref": "release-25.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "fenix": "fenix", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "rust-analyzer-src": { + "flake": false, + "locked": { + "lastModified": 1688765894, + "narHash": "sha256-+ZjwtTxYn3eTy77XG3R1/cihNdrFZc1JxfojORqhMfU=", + "owner": "rust-lang", + "repo": "rust-analyzer", + "rev": "db0add1ce92af58a92b2a80990044ae21713ae29", + "type": "github" + }, + "original": { + "owner": "rust-lang", + "ref": "nightly", + "repo": "rust-analyzer", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..233ccba --- /dev/null +++ b/flake.nix @@ -0,0 +1,61 @@ +{ + description = "Rust GPU"; + + inputs = { + fenix = { + url = "github:nix-community/fenix/4e37fff6b22c58717577a2006c3907c9eef452dc"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + nixpkgs.url = "github:NixOs/nixpkgs/release-25.05"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils, fenix }: + flake-utils.lib.eachDefaultSystem (system: + let + overlays = [ fenix.overlays.default ]; + pkgs = import nixpkgs { inherit overlays system; }; + + rustPkg = pkgs.fenix.latest.withComponents [ + "cargo" + "clippy" + "rust-src" + "rustc" + "rustc-dev" + "llvm-tools-preview" + "rustfmt" + ]; + in + { + devShells.default = with pkgs; mkShell { + hardeningDisable = [ "fortify" ]; + + # WGPU_ADAPTER_NAME = "vulkan"; + + shellHook = '' + export LD_LIBRARY_PATH=${vulkan-loader}/lib + ''; + + nativeBuildInputs = [ + pkg-config + gdb + rustPkg + rust-analyzer-nightly + gcc + spirv-tools + ]; + + buildInputs = [ + xorg.libX11 + xorg.libXcursor + xwayland + xorg.libXrandr + xorg.libXi + vulkan-loader + vulkan-tools + vulkan-headers + vulkan-validation-layers + ]; + }; + }); +} diff --git a/src/lib.rs b/src/lib.rs index 034ad2a..7466a5b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,10 @@ -pub mod epga1d; -pub mod ppga1d; -pub mod hpga1d; + +#![no_std] + +#[cfg_attr(not(target_arch = "spirv"), allow(unused_imports))] +use spirv_std::num_traits::Float; + pub mod epga2d; -pub mod ppga2d; -pub mod hpga2d; -pub mod epga3d; -pub mod ppga3d; -pub mod hpga3d; -pub mod simd; -pub mod polynomial; impl Zero for f32 { fn zero() -> Self { @@ -114,7 +110,7 @@ impl Signum for f32 { type Output = f32; fn signum(self) -> f32 { - f32::signum(self) + ::signum(self) } } @@ -143,181 +139,6 @@ impl Transformation for f32 { } } -impl epga1d::ComplexNumber { - pub fn real(self) -> f32 { - self[0] - } - - pub fn imaginary(self) -> f32 { - self[1] - } - - pub fn from_polar(magnitude: f32, argument: f32) -> Self { - Self::new(magnitude * argument.cos(), magnitude * argument.sin()) - } - - pub fn arg(self) -> f32 { - self.imaginary().atan2(self.real()) - } -} - -impl Exp for epga1d::ComplexNumber { - type Output = Self; - - fn exp(self) -> Self { - Self::from_polar(self[0].exp(), self[1]) - } -} - -impl Ln for epga1d::ComplexNumber { - type Output = Self; - - fn ln(self) -> Self { - Self::new(self.magnitude().ln(), self.arg()) - } -} - -impl Powf for epga1d::ComplexNumber { - type Output = Self; - - fn powf(self, exponent: f32) -> Self { - Self::from_polar(self.magnitude().powf(exponent), self.arg() * exponent) - } -} - -impl Exp for ppga2d::IdealPoint { - type Output = ppga2d::Translator; - - fn exp(self) -> ppga2d::Translator { - ppga2d::Translator::new(1.0, self[0], self[1]) - } -} - -impl Ln for ppga2d::Translator { - type Output = ppga2d::IdealPoint; - - fn ln(self) -> ppga2d::IdealPoint { - let result: ppga2d::IdealPoint = self.into(); - result * (1.0 / self[0]) - } -} - -impl Powf for ppga2d::Translator { - type Output = Self; - - fn powf(self, exponent: f32) -> Self { - (self.ln() * exponent).exp() - } -} - -impl Exp for ppga2d::Point { - type Output = ppga2d::Motor; - - fn exp(self) -> ppga2d::Motor { - let det = self[0] * self[0]; - if det <= 0.0 { - return ppga2d::Motor::new(1.0, 0.0, self[1], self[2]); - } - let a = det.sqrt(); - let c = a.cos(); - let s = a.sin() / a; - let g0 = simd::Simd32x3::from(s) * self.group0(); - ppga2d::Motor::new(c, g0[0], g0[1], g0[2]) - } -} - -impl Ln for ppga2d::Motor { - type Output = ppga2d::Point; - - fn ln(self) -> ppga2d::Point { - let det = 1.0 - self[0] * self[0]; - if det <= 0.0 { - return ppga2d::Point::new(0.0, self[2], self[3]); - } - let a = 1.0 / det; - let b = self[0].acos() * a.sqrt(); - let g0 = simd::Simd32x4::from(b) * self.group0(); - ppga2d::Point::new(g0[1], g0[2], g0[3]) - } -} - -impl Powf for ppga2d::Motor { - type Output = Self; - - fn powf(self, exponent: f32) -> Self { - (self.ln() * exponent).exp() - } -} - -impl Exp for ppga3d::IdealPoint { - type Output = ppga3d::Translator; - - fn exp(self) -> ppga3d::Translator { - ppga3d::Translator::new(1.0, self[0], self[1], self[2]) - } -} - -impl Ln for ppga3d::Translator { - type Output = ppga3d::IdealPoint; - - fn ln(self) -> ppga3d::IdealPoint { - let result: ppga3d::IdealPoint = self.into(); - result * (1.0 / self[0]) - } -} - -impl Powf for ppga3d::Translator { - type Output = Self; - - fn powf(self, exponent: f32) -> Self { - (self.ln() * exponent).exp() - } -} - -impl Exp for ppga3d::Line { - type Output = ppga3d::Motor; - - fn exp(self) -> ppga3d::Motor { - let det = self[3] * self[3] + self[4] * self[4] + self[5] * self[5]; - if det <= 0.0 { - return ppga3d::Motor::new(1.0, 0.0, 0.0, 0.0, 0.0, self[0], self[1], self[2]); - } - let a = det.sqrt(); - let c = a.cos(); - let s = a.sin() / a; - let m = self[0] * self[3] + self[1] * self[4] + self[2] * self[5]; - let t = m / det * (c - s); - let g0 = simd::Simd32x3::from(s) * self.group1(); - let g1 = simd::Simd32x3::from(s) * self.group0() + simd::Simd32x3::from(t) * self.group1(); - ppga3d::Motor::new(c, g0[0], g0[1], g0[2], s * m, g1[0], g1[1], g1[2]) - } -} - -impl Ln for ppga3d::Motor { - type Output = ppga3d::Line; - - fn ln(self) -> ppga3d::Line { - let det = 1.0 - self[0] * self[0]; - if det <= 0.0 { - return ppga3d::Line::new(self[5], self[6], self[7], 0.0, 0.0, 0.0); - } - let a = 1.0 / det; - let b = self[0].acos() * a.sqrt(); - let c = a * self[4] * (1.0 - self[0] * b); - let g0 = simd::Simd32x4::from(b) * self.group1() + simd::Simd32x4::from(c) * self.group0(); - let g1 = simd::Simd32x4::from(b) * self.group0(); - ppga3d::Line::new(g0[1], g0[2], g0[3], g1[1], g1[2], g1[3]) - } -} - -impl Powf for ppga3d::Motor { - type Output = Self; - - fn powf(self, exponent: f32) -> Self { - (self.ln() * exponent).exp() - } -} - /// All elements set to `0.0` pub trait Zero { fn zero() -> Self;