From d73b6a253c928bd450fa24139947e0d9924ad717 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Wed, 20 Sep 2023 15:12:01 +0200 Subject: [PATCH] Blank scalars without the wrapper class. --- codegen/src/ast.rs | 23 ++++++ codegen/src/compile.rs | 4 + codegen/src/glsl.rs | 50 ++++++++---- codegen/src/rust.rs | 180 +++++++++++++++++++++++------------------ src/lib.rs | 149 +++++++++++++++++++++++++++++++--- src/polynomial.rs | 18 ++--- 6 files changed, 309 insertions(+), 115 deletions(-) diff --git a/codegen/src/ast.rs b/codegen/src/ast.rs index bb28279..99b786d 100644 --- a/codegen/src/ast.rs +++ b/codegen/src/ast.rs @@ -7,6 +7,16 @@ pub enum DataType<'a> { MultiVector(&'a MultiVectorClass), } +impl DataType<'_> { + pub fn is_scalar(&self) -> bool { + match self { + Self::SimdVector(1) => true, + Self::MultiVector(multi_vector_class) => multi_vector_class.is_scalar(), + _ => false, + } + } +} + #[derive(PartialEq, Eq, Clone, Debug)] pub enum ExpressionContent<'a> { None, @@ -42,6 +52,19 @@ pub struct Expression<'a> { pub content: ExpressionContent<'a>, } +impl Expression<'_> { + pub fn is_scalar(&self) -> bool { + if self.size > 1 { + return false; + } + match &self.content { + ExpressionContent::Variable(data_type, _) => data_type.is_scalar(), + ExpressionContent::InvokeInstanceMethod(_, _, _, result_data_type, _) => result_data_type.is_scalar(), + _ => false, + } + } +} + #[derive(PartialEq, Eq, Clone, Debug)] pub struct Parameter<'a> { pub name: &'static str, diff --git a/codegen/src/compile.rs b/codegen/src/compile.rs index 6c46fba..f531be4 100644 --- a/codegen/src/compile.rs +++ b/codegen/src/compile.rs @@ -146,6 +146,10 @@ impl MultiVectorClass { self.grouped_basis.iter().flatten().cloned().collect() } + pub fn is_scalar(&self) -> bool { + self.flat_basis() == vec![BasisElement { scalar: 1, index: 0 }] + } + pub fn signature(&self) -> Vec { let mut signature: Vec = self.grouped_basis.iter().flatten().map(|element| element.index).collect(); signature.sort_unstable(); diff --git a/codegen/src/glsl.rs b/codegen/src/glsl.rs index bab3ddb..344fe3d 100644 --- a/codegen/src/glsl.rs +++ b/codegen/src/glsl.rs @@ -10,7 +10,8 @@ fn emit_data_type(collector: &mut W, data_type: &DataType) -> DataType::Integer => collector.write_all(b"int"), DataType::SimdVector(size) if *size == 1 => collector.write_all(b"float"), DataType::SimdVector(size) => collector.write_fmt(format_args!("vec{}", *size)), - DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)), + DataType::MultiVector(class) if class.is_scalar() => collector.write_all(b"float"), + DataType::MultiVector(class) => collector.write_all(class.class_name.as_bytes()), } } @@ -20,6 +21,9 @@ fn emit_expression(collector: &mut W, expression: &Expression ExpressionContent::Variable(_data_type, name) => { collector.write_all(name.bytes().collect::>().as_slice())?; } + ExpressionContent::InvokeClassMethod(class, "Constructor", arguments) if class.is_scalar() => { + emit_expression(collector, &arguments[0].1)?; + } ExpressionContent::InvokeClassMethod(_, _, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, _, _, arguments) => { match &expression.content { ExpressionContent::InvokeInstanceMethod(result_class, inner_expression, method_name, _, _) => { @@ -78,7 +82,9 @@ fn emit_expression(collector: &mut W, expression: &Expression } ExpressionContent::Access(inner_expression, array_index) => { emit_expression(collector, inner_expression)?; - collector.write_fmt(format_args!(".g{}", array_index))?; + if !inner_expression.is_scalar() { + collector.write_fmt(format_args!(".g{}", array_index))?; + } } ExpressionContent::Swizzle(inner_expression, indices) => { emit_expression(collector, inner_expression)?; @@ -88,22 +94,28 @@ fn emit_expression(collector: &mut W, expression: &Expression } } ExpressionContent::Gather(inner_expression, indices) => { - if expression.size > 1 { - emit_data_type(collector, &DataType::SimdVector(expression.size))?; - collector.write_all(b"(")?; - } - for (i, (array_index, component_index)) in indices.iter().enumerate() { - if i > 0 { - collector.write_all(b", ")?; - } + if expression.size == 1 && inner_expression.is_scalar() { emit_expression(collector, inner_expression)?; - collector.write_fmt(format_args!(".g{}", array_index))?; - if inner_expression.size > 1 { - collector.write_fmt(format_args!(".{}", COMPONENT[*component_index]))?; + } else { + if expression.size > 1 { + emit_data_type(collector, &DataType::SimdVector(expression.size))?; + collector.write_all(b"(")?; + } + for (i, (array_index, component_index)) in indices.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_expression(collector, inner_expression)?; + if !inner_expression.is_scalar() { + collector.write_fmt(format_args!(".g{}", array_index))?; + if inner_expression.size > 1 { + collector.write_fmt(format_args!(".{}", COMPONENT[*component_index]))?; + } + } + } + if expression.size > 1 { + collector.write_all(b")")?; } - } - if expression.size > 1 { - collector.write_all(b")")?; } } ExpressionContent::Constant(data_type, values) => match data_type { @@ -163,6 +175,9 @@ pub fn emit_code(collector: &mut W, ast_node: &AstNode, inden AstNode::None => {} AstNode::Preamble => {} AstNode::ClassDefinition { class } => { + if class.is_scalar() { + return Ok(()); + } collector.write_fmt(format_args!("struct {} {{\n", class.class_name))?; for (i, group) in class.grouped_basis.iter().enumerate() { emit_indentation(collector, indentation + 1)?; @@ -212,7 +227,8 @@ pub fn emit_code(collector: &mut W, ast_node: &AstNode, inden collector.write_all(b"}\n")?; } AstNode::TraitImplementation { result, parameters, body } => { - collector.write_fmt(format_args!("{} ", result.multi_vector_class().class_name))?; + emit_data_type(collector, &result.data_type)?; + collector.write_all(b" ")?; match parameters.len() { 0 => camel_to_snake_case(collector, &result.multi_vector_class().class_name)?, 1 if result.name == "Into" => { diff --git a/codegen/src/rust.rs b/codegen/src/rust.rs index a141e03..6e6dec6 100644 --- a/codegen/src/rust.rs +++ b/codegen/src/rust.rs @@ -8,6 +8,7 @@ fn emit_data_type(collector: &mut W, data_type: &DataType) -> DataType::Integer => collector.write_all(b"isize"), DataType::SimdVector(size) if *size == 1 => collector.write_all(b"f32"), DataType::SimdVector(size) => collector.write_fmt(format_args!("Simd32x{}", *size)), + DataType::MultiVector(class) if class.is_scalar() => collector.write_all(b"f32"), DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)), } } @@ -18,39 +19,45 @@ fn emit_expression(collector: &mut W, expression: &Expression ExpressionContent::Variable(_data_type, name) => { collector.write_all(name.bytes().collect::>().as_slice())?; } - ExpressionContent::InvokeClassMethod(_, method_name, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, method_name, _, arguments) => { - match &expression.content { - ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, _, _, _) => { - emit_expression(collector, inner_expression)?; - collector.write_all(b".")?; - } - ExpressionContent::InvokeClassMethod(class, _, _) => { - if *method_name == "Constructor" { - collector.write_fmt(format_args!("{} {{ groups: {}Groups {{ ", class.class_name, class.class_name))?; - } else { - collector.write_fmt(format_args!("{}::", class.class_name))?; - } - } - _ => unreachable!(), - } - if *method_name != "Constructor" { - camel_to_snake_case(collector, method_name)?; - collector.write_all(b"(")?; - } + ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, method_name, _, arguments) => { + emit_expression(collector, inner_expression)?; + collector.write_all(b".")?; + camel_to_snake_case(collector, method_name)?; + collector.write_all(b"(")?; for (i, (_argument_class, argument)) in arguments.iter().enumerate() { if i > 0 { collector.write_all(b", ")?; } - if *method_name == "Constructor" { - collector.write_fmt(format_args!("g{}: ", i))?; + emit_expression(collector, argument)?; + } + collector.write_all(b")")?; + } + ExpressionContent::InvokeClassMethod(class, "Constructor", arguments) if class.is_scalar() => { + emit_expression(collector, &arguments[0].1)?; + } + ExpressionContent::InvokeClassMethod(class, "Constructor", arguments) => { + collector.write_fmt(format_args!("{} {{ groups: {}Groups {{ ", class.class_name, class.class_name))?; + for (i, (_argument_class, argument)) in arguments.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("g{}: ", i))?; + emit_expression(collector, argument)?; + } + collector.write_all(b" } }")?; + } + ExpressionContent::InvokeClassMethod(class, method_name, arguments) => { + emit_data_type(collector, &DataType::MultiVector(class))?; + collector.write_all(b"::")?; + camel_to_snake_case(collector, method_name)?; + collector.write_all(b"(")?; + for (i, (_argument_class, argument)) in arguments.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; } emit_expression(collector, argument)?; } - if *method_name == "Constructor" { - collector.write_all(b" } }")?; - } else { - collector.write_all(b")")?; - } + collector.write_all(b")")?; } ExpressionContent::Conversion(_source_class, _destination_class, inner_expression) => { emit_expression(collector, inner_expression)?; @@ -67,7 +74,9 @@ fn emit_expression(collector: &mut W, expression: &Expression } ExpressionContent::Access(inner_expression, array_index) => { emit_expression(collector, inner_expression)?; - collector.write_fmt(format_args!(".group{}()", array_index))?; + if !inner_expression.is_scalar() { + collector.write_fmt(format_args!(".group{}()", array_index))?; + } } ExpressionContent::Swizzle(inner_expression, indices) => { if expression.size == 1 { @@ -89,28 +98,34 @@ fn emit_expression(collector: &mut W, expression: &Expression } } ExpressionContent::Gather(inner_expression, indices) => { - if expression.size > 1 { - emit_data_type(collector, &DataType::SimdVector(expression.size))?; - collector.write_all(b"::from(")?; - } - if indices.len() > 1 { - collector.write_all(b"[")?; - } - for (i, (array_index, component_index)) in indices.iter().enumerate() { - if i > 0 { - collector.write_all(b", ")?; - } + if expression.size == 1 && inner_expression.is_scalar() { emit_expression(collector, inner_expression)?; - collector.write_fmt(format_args!(".group{}()", array_index))?; - if inner_expression.size > 1 { - collector.write_fmt(format_args!("[{}]", *component_index))?; + } else { + if expression.size > 1 { + emit_data_type(collector, &DataType::SimdVector(expression.size))?; + collector.write_all(b"::from(")?; + } + if indices.len() > 1 { + collector.write_all(b"[")?; + } + for (i, (array_index, component_index)) in indices.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_expression(collector, inner_expression)?; + if !inner_expression.is_scalar() { + collector.write_fmt(format_args!(".group{}()", array_index))?; + if inner_expression.size > 1 { + collector.write_fmt(format_args!("[{}]", *component_index))?; + } + } + } + if indices.len() > 1 { + collector.write_all(b"]")?; + } + if expression.size > 1 { + collector.write_all(b")")?; } - } - if indices.len() > 1 { - collector.write_all(b"]")?; - } - if expression.size > 1 { - collector.write_all(b")")?; } } ExpressionContent::Constant(data_type, values) => match data_type { @@ -172,17 +187,15 @@ fn emit_assign_trait(collector: &mut W, result: &Parameter, p if result.multi_vector_class() != parameters[0].multi_vector_class() { return Ok(()); } - collector.write_fmt(format_args!( - "impl {}Assign<{}> for {} {{\n fn ", - result.name, - parameters[1].multi_vector_class().class_name, - parameters[0].multi_vector_class().class_name - ))?; + collector.write_fmt(format_args!("impl {}Assign<", result.name))?; + emit_data_type(collector, ¶meters[1].data_type)?; + collector.write_all(b"> for ")?; + emit_data_type(collector, ¶meters[0].data_type)?; + collector.write_all(b" {\n fn ")?; camel_to_snake_case(collector, result.name)?; - collector.write_fmt(format_args!( - "_assign(&mut self, other: {}) {{\n *self = (*self).", - parameters[1].multi_vector_class().class_name - ))?; + collector.write_all(b"_assign(&mut self, other: ")?; + emit_data_type(collector, ¶meters[1].data_type)?; + collector.write_all(b") {\n *self = (*self).")?; camel_to_snake_case(collector, result.name)?; collector.write_all(b"(other);\n }\n}\n\n") } @@ -196,6 +209,9 @@ pub fn emit_code(collector: &mut W, ast_node: &AstNode, inden .write_all(b"use crate::{simd::*, *};\nuse std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};\n\n")?; } AstNode::ClassDefinition { class } => { + if class.is_scalar() { + return Ok(()); + } let element_count = class.grouped_basis.iter().fold(0, |a, b| a + b.len()); let mut simd_widths = Vec::new(); emit_indentation(collector, indentation)?; @@ -471,30 +487,40 @@ pub fn emit_code(collector: &mut W, ast_node: &AstNode, inden collector.write_all(b"}\n")?; } AstNode::TraitImplementation { result, parameters, body } => { - match parameters.len() { - 0 => collector.write_fmt(format_args!("impl {} for {}", result.name, result.multi_vector_class().class_name))?, - 1 if result.name == "Into" => collector.write_fmt(format_args!( - "impl {}<{}> for {}", - result.name, - result.multi_vector_class().class_name, - parameters[0].multi_vector_class().class_name, - ))?, - 1 => collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))?, - 2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => { - collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))? - } - 2 => collector.write_fmt(format_args!( - "impl {}<{}> for {}", - result.name, - parameters[1].multi_vector_class().class_name, - parameters[0].multi_vector_class().class_name, - ))?, - _ => unreachable!(), + if result.data_type.is_scalar() + && !parameters + .iter() + .any(|parameter| matches!(parameter.data_type, DataType::MultiVector(class) if !class.is_scalar())) + { + return Ok(()); } + collector.write_fmt(format_args!("impl {}", result.name))?; + let impl_for = match parameters.len() { + 0 => &result.data_type, + 1 if result.name == "Into" => { + collector.write_all(b"<")?; + emit_data_type(collector, &result.data_type)?; + collector.write_all(b">")?; + ¶meters[0].data_type + } + 1 => ¶meters[0].data_type, + 2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => ¶meters[0].data_type, + 2 => { + collector.write_all(b"<")?; + emit_data_type(collector, ¶meters[1].data_type)?; + collector.write_all(b">")?; + ¶meters[0].data_type + } + _ => unreachable!(), + }; + collector.write_all(b" for ")?; + emit_data_type(collector, impl_for)?; collector.write_all(b" {\n")?; if !parameters.is_empty() && result.name != "Into" { emit_indentation(collector, indentation + 1)?; - collector.write_fmt(format_args!("type Output = {};\n\n", result.multi_vector_class().class_name))?; + collector.write_all(b"type Output = ")?; + emit_data_type(collector, &result.data_type)?; + collector.write_all(b";\n\n")?; } emit_indentation(collector, indentation + 1)?; collector.write_all(b"fn ")?; diff --git a/src/lib.rs b/src/lib.rs index 0169005..e325eda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,17 +10,144 @@ pub mod hpga3d; pub mod simd; pub mod polynomial; -impl epga1d::Scalar { - pub fn real(self) -> f32 { - self[0] +impl Zero for f32 { + fn zero() -> Self { + 0.0 } +} - pub fn sqrt(self) -> epga1d::ComplexNumber { - if self[0] < 0.0 { - epga1d::ComplexNumber::new(0.0, (-self[0]).sqrt()) - } else { - epga1d::ComplexNumber::new(self[0].sqrt(), 0.0) - } +impl One for f32 { + fn one() -> Self { + 1.0 + } +} + +impl Automorphism for f32 { + type Output = f32; + + fn automorphism(self) -> f32 { + self + } +} + +impl Reversal for f32 { + type Output = f32; + + fn reversal(self) -> f32 { + self + } +} + +impl Conjugation for f32 { + type Output = f32; + + fn conjugation(self) -> f32 { + self + } +} + +impl GeometricProduct for f32 { + type Output = f32; + + fn geometric_product(self, other: f32) -> f32 { + self * other + } +} + +impl OuterProduct for f32 { + type Output = f32; + + fn outer_product(self, other: f32) -> f32 { + self * other + } +} + +impl InnerProduct for f32 { + type Output = f32; + + fn inner_product(self, other: f32) -> f32 { + self * other + } +} + +impl LeftContraction for f32 { + type Output = f32; + + fn left_contraction(self, other: f32) -> f32 { + self * other + } +} + +impl RightContraction for f32 { + type Output = f32; + + fn right_contraction(self, other: f32) -> f32 { + self * other + } +} + +impl ScalarProduct for f32 { + type Output = f32; + + fn scalar_product(self, other: f32) -> f32 { + self * other + } +} + +impl SquaredMagnitude for f32 { + type Output = f32; + + fn squared_magnitude(self) -> f32 { + self.scalar_product(self.reversal()) + } +} + +impl Magnitude for f32 { + type Output = f32; + + fn magnitude(self) -> f32 { + self.abs() + } +} + +impl Scale for f32 { + type Output = f32; + + fn scale(self, other: f32) -> f32 { + self.geometric_product(other) + } +} + +impl Signum for f32 { + type Output = f32; + + fn signum(self) -> f32 { + f32::signum(self) + } +} + +impl Inverse for f32 { + type Output = f32; + + fn inverse(self) -> f32 { + 1.0 / self + } +} + +impl GeometricQuotient for f32 { + type Output = f32; + + fn geometric_quotient(self, other: f32) -> f32 { + self.geometric_product(other.inverse()) + } +} + +impl Transformation for f32 { + type Output = f32; + + fn transformation(self, other: f32) -> f32 { + self.geometric_product(other) + .geometric_product(self.reversal()) } } @@ -54,7 +181,7 @@ impl Ln for epga1d::ComplexNumber { type Output = Self; fn ln(self) -> Self { - Self::new(self.magnitude()[0].ln(), self.arg()) + Self::new(self.magnitude().ln(), self.arg()) } } @@ -62,7 +189,7 @@ impl Powf for epga1d::ComplexNumber { type Output = Self; fn powf(self, exponent: f32) -> Self { - Self::from_polar(self.magnitude()[0].powf(exponent), self.arg() * exponent) + Self::from_polar(self.magnitude().powf(exponent), self.arg() * exponent) } } diff --git a/src/polynomial.rs b/src/polynomial.rs index 4fb065d..3188bc0 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -50,7 +50,7 @@ pub fn solve_quadratic(coefficients: [f32; 3], error_margin: f32) -> (f32, Vec (f32, Vec ]; let mut solutions = Vec::with_capacity(3); let discriminant = d[1].powi(2) - 4.0 * d[0].powi(3); - let c = Scalar::new(discriminant).sqrt(); - let c = ((c + ComplexNumber::new(if c.real() + d[1] == 0.0 { -d[1] } else { d[1] }, 0.0)) - .scale(0.5)) - .powf(1.0 / 3.0); + let c = discriminant.sqrt(); + let c = ((c + ComplexNumber::new(if c + d[1] == 0.0 { -d[1] } else { d[1] }, 0.0)).scale(0.5)) + .powf(1.0 / 3.0); for root_of_unity in &ROOTS_OF_UNITY_3 { let ci = c.geometric_product(*root_of_unity); let denominator = ci.scale(3.0 * coefficients[3]); @@ -105,7 +104,7 @@ pub fn solve_cubic(coefficients: [f32; 4], error_margin: f32) -> (f32, Vec .geometric_product(denominator.reversal()); solutions.push(Root { numerator, - denominator: denominator.squared_magnitude().real(), + denominator: denominator.squared_magnitude(), }); } let real_root = @@ -144,10 +143,9 @@ pub fn solve_quartic(coefficients: [f32; 5], error_margin: f32) -> (f32, Vec