diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1de5659 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +target \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..6d86527 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "geometric_algebra" +version = "0.1.0" +authors = ["Alexander Meißner "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +simd = { path = "simd" } diff --git a/codegen/Cargo.toml b/codegen/Cargo.toml new file mode 100644 index 0000000..5b83ee0 --- /dev/null +++ b/codegen/Cargo.toml @@ -0,0 +1,5 @@ +[package] +name = "codegen" +version = "0.1.0" +authors = ["Alexander Meißner "] +edition = "2018" diff --git a/codegen/rustfmt.toml b/codegen/rustfmt.toml new file mode 100644 index 0000000..09c2684 --- /dev/null +++ b/codegen/rustfmt.toml @@ -0,0 +1 @@ +max_width=150 \ No newline at end of file diff --git a/codegen/src/algebra.rs b/codegen/src/algebra.rs new file mode 100644 index 0000000..9c4a76c --- /dev/null +++ b/codegen/src/algebra.rs @@ -0,0 +1,291 @@ +pub struct GeometricAlgebra<'a> { + pub generator_squares: &'a [isize], +} + +impl<'a> GeometricAlgebra<'a> { + pub fn basis_size(&self) -> usize { + 1 << self.generator_squares.len() + } + + pub fn basis(&self) -> impl Iterator + '_ { + (0..self.basis_size() as BasisElementIndex).map(|index| BasisElement { index }) + } + + pub fn sorted_basis(&self) -> Vec { + let mut basis_elements = self.basis().collect::>(); + basis_elements.sort(); + basis_elements + } +} + +type BasisElementIndex = u16; + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct BasisElement { + pub index: BasisElementIndex, +} + +impl BasisElement { + pub fn new(name: &str) -> Self { + Self { + index: if name == "1" { + 0 + } else { + let mut generator_indices = name.chars(); + assert_eq!(generator_indices.next().unwrap(), 'e'); + generator_indices.fold(0, |index, generator_index| index | (1 << (generator_index.to_digit(16).unwrap()))) + }, + } + } + + pub fn grade(&self) -> usize { + self.index.count_ones() as usize + } + + pub fn component_bits(&self) -> impl Iterator + '_ { + (0..std::mem::size_of::() * 8).filter(move |index| (self.index >> index) & 1 != 0) + } + + pub fn dual(&self, algebra: &GeometricAlgebra) -> Self { + Self { + index: algebra.basis_size() as BasisElementIndex - 1 - self.index, + } + } +} + +impl std::fmt::Display for BasisElement { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.index == 0 { + formatter.pad("1") + } else { + let string = format!("e{}", self.component_bits().map(|index| format!("{:X}", index)).collect::()); + formatter.pad(string.as_str()) + } + } +} + +impl std::cmp::Ord for BasisElement { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + let grades_order = self.grade().cmp(&other.grade()); + if grades_order != std::cmp::Ordering::Equal { + return grades_order; + } + let a_without_b = self.index & (!other.index); + let b_without_a = other.index & (!self.index); + if a_without_b.trailing_zeros() < b_without_a.trailing_zeros() { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Greater + } + } +} + +impl std::cmp::PartialOrd for BasisElement { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +#[derive(Clone, PartialEq)] +pub struct ScaledElement { + pub scalar: isize, + pub unit: BasisElement, +} + +impl ScaledElement { + pub fn from(element: &BasisElement) -> Self { + Self { + scalar: 1, + unit: element.clone(), + } + } + + pub fn product(a: &Self, b: &Self, algebra: &GeometricAlgebra) -> Self { + let commutations = a + .unit + .component_bits() + .fold((0, a.unit.index, b.unit.index), |(commutations, a, b), index| { + let hurdles_a = a & (BasisElementIndex::MAX << (index + 1)); + let hurdles_b = b & ((1 << index) - 1); + ( + commutations + + BasisElement { + index: hurdles_a | hurdles_b, + } + .grade(), + a & !(1 << index), + b ^ (1 << index), + ) + }); + Self { + scalar: BasisElement { + index: a.unit.index & b.unit.index, + } + .component_bits() + .map(|i| algebra.generator_squares[i]) + .fold(a.scalar * b.scalar * if commutations.0 % 2 == 0 { 1 } else { -1 }, |a, b| a * b), + unit: BasisElement { + index: a.unit.index ^ b.unit.index, + }, + } + } +} + +impl std::fmt::Display for ScaledElement { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + let string = self.unit.to_string(); + formatter.pad_integral(self.scalar >= 0, "", if self.scalar == 0 { "0" } else { string.as_str() }) + } +} + +#[derive(Clone)] +pub struct Involution { + pub terms: Vec, +} + +impl Involution { + pub fn identity(algebra: &GeometricAlgebra) -> Self { + Self { + terms: algebra.basis().map(|element| ScaledElement { scalar: 1, unit: element }).collect(), + } + } + + pub fn negated(&self, grade_negation: F) -> Self + where + F: Fn(usize) -> bool, + { + Self { + terms: self + .terms + .iter() + .map(|element| ScaledElement { + scalar: if grade_negation(element.unit.grade()) { -1 } else { 1 }, + unit: element.unit.clone(), + }) + .collect(), + } + } + + pub fn dual(&self, algebra: &GeometricAlgebra) -> Self { + Self { + terms: self + .terms + .iter() + .map(|term| ScaledElement { + scalar: term.scalar, + unit: term.unit.dual(algebra), + }) + .collect(), + } + } + + pub fn involutions(algebra: &GeometricAlgebra) -> Vec<(&'static str, Self)> { + let involution = Self::identity(algebra); + vec![ + ("Neg", involution.negated(|_grade| true)), + ("Automorph", involution.negated(|grade| grade % 2 == 1)), + ("Transpose", involution.negated(|grade| grade % 4 >= 2)), + ("Conjugate", involution.negated(|grade| (grade + 3) % 4 < 2)), + ("Dual", involution.dual(algebra)), + ] + } +} + +#[derive(Clone, PartialEq)] +pub struct ProductTerm { + pub product: ScaledElement, + pub factor_a: BasisElement, + pub factor_b: BasisElement, +} + +#[derive(Clone)] +pub struct Product { + pub terms: Vec, +} + +impl Product { + pub fn product(a: &[ScaledElement], b: &[ScaledElement], algebra: &GeometricAlgebra) -> Self { + Self { + terms: a + .iter() + .map(|a| { + b.iter().map(move |b| ProductTerm { + product: ScaledElement::product(&a, &b, algebra), + factor_a: a.unit.clone(), + factor_b: b.unit.clone(), + }) + }) + .flatten() + .filter(|term| term.product.scalar != 0) + .collect(), + } + } + + pub fn projected(&self, grade_projection: F) -> Self + where + F: Fn(usize, usize, usize) -> bool, + { + Self { + terms: self + .terms + .iter() + .filter(|term| grade_projection(term.factor_a.grade(), term.factor_b.grade(), term.product.unit.grade())) + .cloned() + .collect(), + } + } + + pub fn dual(&self, algebra: &GeometricAlgebra) -> Self { + Self { + terms: self + .terms + .iter() + .map(|term| ProductTerm { + product: ScaledElement { + scalar: term.product.scalar, + unit: term.product.unit.dual(algebra), + }, + factor_a: term.factor_a.dual(algebra), + factor_b: term.factor_b.dual(algebra), + }) + .collect(), + } + } + + pub fn products(algebra: &GeometricAlgebra) -> Vec<(&'static str, Self)> { + let basis = algebra.basis().map(|element| ScaledElement::from(&element)).collect::>(); + let product = Self::product(&basis, &basis, algebra); + vec![ + ("GeometricProduct", product.clone()), + ("RegressiveProduct", product.projected(|r, s, t| t == r + s).dual(algebra)), + ("OuterProduct", product.projected(|r, s, t| t == r + s)), + ("InnerProduct", product.projected(|r, s, t| t == (r as isize - s as isize).abs() as usize)), + ("LeftContraction", product.projected(|r, s, t| t as isize == s as isize - r as isize)), + ("RightContraction", product.projected(|r, s, t| t as isize == r as isize - s as isize)), + ("ScalarProduct", product.projected(|_r, _s, t| t == 0)), + ] + } +} + +#[derive(Default)] +pub struct MultiVectorClassRegistry { + pub classes: Vec, + index_by_signature: std::collections::HashMap, usize>, +} + +impl MultiVectorClassRegistry { + pub fn register(&mut self, class: MultiVectorClass) { + self.index_by_signature.insert(class.signature(), self.classes.len()); + self.classes.push(class); + } + + pub fn get(&self, signature: &[BasisElement]) -> Option<&MultiVectorClass> { + self.index_by_signature.get(signature).map(|index| &self.classes[*index]) + } +} + +#[derive(PartialEq, Eq)] +pub struct MultiVectorClass { + pub class_name: String, + pub grouped_basis: Vec>, +} diff --git a/codegen/src/ast.rs b/codegen/src/ast.rs new file mode 100644 index 0000000..f6b7fdb --- /dev/null +++ b/codegen/src/ast.rs @@ -0,0 +1,83 @@ +use crate::algebra::MultiVectorClass; + +#[derive(PartialEq, Eq, Clone)] +pub enum DataType<'a> { + Integer, + SimdVector(usize), + MultiVector(&'a MultiVectorClass), +} + +#[derive(PartialEq, Eq, Clone)] +pub enum ExpressionContent<'a> { + None, + Variable(&'static str), + InvokeClassMethod(&'a MultiVectorClass, &'static str, Vec<(DataType<'a>, Expression<'a>)>), + InvokeInstanceMethod(DataType<'a>, Box>, &'static str, Vec<(DataType<'a>, Expression<'a>)>), + Conversion(&'a MultiVectorClass, &'a MultiVectorClass, Box>), + Select(Box>, Box>, Box>), + Access(Box>, usize), + Swizzle(Box>, Vec), + Gather(Box>, Vec<(usize, usize)>), + Constant(DataType<'a>, Vec), + SquareRoot(Box>), + Add(Box>, Box>), + Subtract(Box>, Box>), + Multiply(Box>, Box>), + Divide(Box>, Box>), + LessThan(Box>, Box>), + Equal(Box>, Box>), + LogicAnd(Box>, Box>), + BitShiftRight(Box>, Box>), +} + +#[derive(PartialEq, Eq, Clone)] +pub struct Expression<'a> { + pub size: usize, + pub content: ExpressionContent<'a>, +} + +#[derive(PartialEq, Eq, Clone)] +pub struct Parameter<'a> { + pub name: &'static str, + pub data_type: DataType<'a>, +} + +impl<'a> Parameter<'a> { + pub fn multi_vector_class(&self) -> &'a MultiVectorClass { + if let DataType::MultiVector(class) = self.data_type { + class + } else { + unreachable!() + } + } +} + +#[derive(PartialEq, Eq, Clone)] +pub enum AstNode<'a> { + None, + Preamble, + ClassDefinition { + class: &'a MultiVectorClass, + }, + ReturnStatement { + expression: Box>, + }, + VariableAssignment { + name: &'static str, + data_type: Option>, + expression: Box>, + }, + IfThenBlock { + condition: Box>, + body: Vec>, + }, + WhileLoopBlock { + condition: Box>, + body: Vec>, + }, + TraitImplementation { + result: Parameter<'a>, + parameters: Vec>, + body: Vec>, + }, +} diff --git a/codegen/src/compile.rs b/codegen/src/compile.rs new file mode 100644 index 0000000..faac77d --- /dev/null +++ b/codegen/src/compile.rs @@ -0,0 +1,1261 @@ +use crate::{ + algebra::{BasisElement, Involution, MultiVectorClass, MultiVectorClassRegistry, Product}, + ast::{AstNode, DataType, Expression, ExpressionContent, Parameter}, +}; + +#[macro_export] +macro_rules! result_of_trait { + ($ast_node:expr) => { + match $ast_node { + AstNode::TraitImplementation { ref result, .. } => result, + _ => unreachable!(), + } + }; +} + +pub fn simplify_and_legalize(expression: Box) -> Box { + match expression.content { + ExpressionContent::Gather(mut inner_expression, indices) => { + if let Some(first_index_pair) = indices.first() { + inner_expression = simplify_and_legalize(inner_expression); + if indices.iter().all(|index_pair| index_pair == first_index_pair) { + Box::new(Expression { + size: expression.size, + content: ExpressionContent::Gather(inner_expression, vec![*first_index_pair]), + }) + } else if inner_expression.size == expression.size && indices.iter().all(|(array_index, _)| *array_index == first_index_pair.0) { + inner_expression = Box::new(Expression { + size: expression.size, + content: ExpressionContent::Access(inner_expression, first_index_pair.0), + }); + if indices.iter().enumerate().any(|(i, (_, component_index))| i != *component_index) { + Box::new(Expression { + size: expression.size, + content: ExpressionContent::Swizzle( + inner_expression, + indices.iter().map(|(_, component_index)| *component_index).collect(), + ), + }) + } else { + inner_expression + } + } else { + Box::new(Expression { + size: expression.size, + content: ExpressionContent::Gather(inner_expression, indices), + }) + } + } else { + Box::new(Expression { + size: expression.size, + content: ExpressionContent::None, + }) + } + } + ExpressionContent::Constant(ref data_type, ref values) => { + let first_value = values.first().unwrap(); + if values.iter().all(|value| value == first_value) { + Box::new(Expression { + size: expression.size, + content: ExpressionContent::Constant(data_type.clone(), vec![*first_value]), + }) + } else { + expression + } + } + ExpressionContent::Add(mut a, mut b) => { + if let ExpressionContent::Multiply(ref c, ref d) = b.content { + if let ExpressionContent::Multiply(ref e, ref f) = d.content { + if let ExpressionContent::Constant(_data_type, values) = &f.content { + if values.iter().all(|value| *value == -1) { + b = Box::new(Expression { + size: expression.size, + content: ExpressionContent::Multiply(c.clone(), e.clone()), + }); + return simplify_and_legalize(Box::new(Expression { + size: expression.size, + content: ExpressionContent::Subtract(a, b), + })); + } + } + } + } + a = simplify_and_legalize(a); + b = simplify_and_legalize(b); + if a.content == ExpressionContent::None { + b + } else if b.content == ExpressionContent::None { + a + } else { + Box::new(Expression { + size: expression.size, + content: ExpressionContent::Add(a, b), + }) + } + } + ExpressionContent::Subtract(mut a, mut b) => { + a = simplify_and_legalize(a); + b = simplify_and_legalize(b); + if a.content == ExpressionContent::None { + let constant = Expression { + size: expression.size, + content: ExpressionContent::Constant(DataType::SimdVector(expression.size), vec![0]), + }; + Box::new(Expression { + size: expression.size, + content: ExpressionContent::Subtract(Box::new(constant), b), + }) + } else if b.content == ExpressionContent::None { + a + } else { + Box::new(Expression { + size: expression.size, + content: ExpressionContent::Subtract(a, b), + }) + } + } + ExpressionContent::Multiply(mut a, mut b) => { + a = simplify_and_legalize(a); + b = simplify_and_legalize(b); + if let ExpressionContent::Constant(_, _) = a.content { + std::mem::swap(&mut a, &mut b) + } + if a.content == ExpressionContent::None { + b + } else { + match b.content { + ExpressionContent::None => a, + ExpressionContent::Constant(_data_type, c) if c.iter().all(|c| *c == 1) => a, + ExpressionContent::Constant(_data_type, c) if c.iter().all(|c| *c == 0) => Box::new(Expression { + size: expression.size, + content: ExpressionContent::None, + }), + _ => Box::new(Expression { + size: expression.size, + content: ExpressionContent::Multiply(a, b), + }), + } + } + } + _ => expression, + } +} + +impl MultiVectorClass { + pub fn flat_basis(&self) -> Vec { + self.grouped_basis.iter().flatten().cloned().collect() + } + + pub fn signature(&self) -> Vec { + let mut signature = self.flat_basis(); + signature.sort(); + signature + } + + pub fn index_in_group(&self, mut index: usize) -> (usize, usize) { + for (group_index, group) in self.grouped_basis.iter().enumerate() { + if index >= group.len() { + index -= group.len(); + } else { + return (group_index, index); + } + } + unreachable!() + } + + pub fn constant<'a>(&'a self, name: &'static str) -> AstNode<'a> { + let (scalar_value, other_values) = match name { + "Zero" => (0, 0), + "One" => (1, 0), + _ => unreachable!(), + }; + let mut body = Vec::new(); + for result_group in self.grouped_basis.iter() { + let size = result_group.len(); + let expression = Expression { + size, + content: ExpressionContent::Constant( + DataType::SimdVector(size), + result_group + .iter() + .map(|element| if element.index == 0 { scalar_value } else { other_values }) + .collect(), + ), + }; + body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression)))); + } + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: DataType::MultiVector(self), + }, + parameters: vec![], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod(self, "Constructor", body), + }), + }], + } + } + + pub fn involution<'a>( + name: &'static str, + involution: &Involution, + parameter_a: &Parameter<'a>, + registry: &'a MultiVectorClassRegistry, + ) -> AstNode<'a> { + let a_flat_basis = parameter_a.multi_vector_class().flat_basis(); + let result_signature = a_flat_basis + .iter() + .map(|element| involution.terms[element.index as usize].unit.clone()) + .collect::>(); + let mut result_signature = result_signature.into_iter().collect::>(); + result_signature.sort(); + if let Some(result_class) = registry.get(&result_signature) { + let result_flat_basis = result_class.flat_basis(); + let mut body = Vec::new(); + let mut base_index = 0; + for result_group in result_class.grouped_basis.iter() { + let size = result_group.len(); + let a_indices = (0..size) + .map(|index_in_group| { + let result_element = &result_flat_basis[base_index + index_in_group]; + let a_element = &involution.terms[result_element.index as usize].unit; + parameter_a + .multi_vector_class() + .index_in_group(a_flat_basis.iter().position(|element| element == a_element).unwrap()) + }) + .collect::>(); + let a_group_index = a_indices[0].0; + let expression = Expression { + size, + content: ExpressionContent::Multiply( + Box::new(Expression { + size, + content: ExpressionContent::Gather( + Box::new(Expression { + size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(), + content: ExpressionContent::Variable(parameter_a.name), + }), + a_indices, + ), + }), + Box::new(Expression { + size, + content: ExpressionContent::Constant( + DataType::SimdVector(size), + result_group + .iter() + .map(|element| involution.terms[element.index as usize].scalar) + .collect(), + ), + }), + ), + }; + body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression)))); + base_index += size; + } + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: DataType::MultiVector(result_class), + }, + parameters: vec![parameter_a.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body), + }), + }], + } + } else { + AstNode::None + } + } + + pub fn conversion<'a>(name: &'static str, parameter_a: &Parameter<'a>, result_class: &'a MultiVectorClass) -> AstNode<'a> { + if parameter_a.multi_vector_class() == result_class { + return AstNode::None; + } + let a_flat_basis = parameter_a.multi_vector_class().flat_basis(); + let result_flat_basis = result_class.flat_basis(); + for element in &result_flat_basis { + if !a_flat_basis.contains(element) { + return AstNode::None; + } + } + let mut body = Vec::new(); + let mut base_index = 0; + for result_group in result_class.grouped_basis.iter() { + let size = result_group.len(); + let a_indices = (0..size) + .map(|index_in_group| { + let result_element = &result_flat_basis[base_index + index_in_group]; + parameter_a + .multi_vector_class() + .index_in_group(a_flat_basis.iter().position(|a_element| a_element == result_element).unwrap()) + }) + .collect::>(); + let a_group_index = a_indices[0].0; + let expression = Expression { + size, + content: ExpressionContent::Gather( + Box::new(Expression { + size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(), + content: ExpressionContent::Variable(parameter_a.name), + }), + a_indices, + ), + }; + body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression)))); + base_index += size; + } + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: DataType::MultiVector(result_class), + }, + parameters: vec![parameter_a.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body), + }), + }], + } + } + + pub fn sum<'a>( + name: &'static str, + parameter_a: &Parameter<'a>, + parameter_b: &Parameter<'a>, + registry: &'a MultiVectorClassRegistry, + ) -> AstNode<'a> { + let a_flat_basis = parameter_a.multi_vector_class().flat_basis(); + let b_flat_basis = parameter_b.multi_vector_class().flat_basis(); + let result_signature = a_flat_basis + .iter() + .chain(b_flat_basis.iter()) + .cloned() + .collect::>(); + let mut result_signature = result_signature.into_iter().collect::>(); + result_signature.sort(); + if let Some(result_class) = registry.get(&result_signature) { + let parameters = [(parameter_a, &a_flat_basis), (parameter_b, &b_flat_basis)]; + let mut body = Vec::new(); + for result_group in result_class.grouped_basis.iter() { + let size = result_group.len(); + let mut expressions = parameters.iter().map(|(parameter, flat_basis)| { + let mut parameter_group_index = None; + let terms: Vec<_> = result_group + .iter() + .map(|result_element| { + if let Some(index_in_group) = flat_basis.iter().position(|element| element == result_element) { + let index_pair = parameter.multi_vector_class().index_in_group(index_in_group); + parameter_group_index = Some(index_pair.0); + (1, index_pair) + } else { + (0, (0, 0)) + } + }) + .collect(); + Expression { + size, + content: ExpressionContent::Multiply( + Box::new(Expression { + size, + content: ExpressionContent::Gather( + Box::new(Expression { + size: if let Some(index) = parameter_group_index { + parameter.multi_vector_class().grouped_basis[index].len() + } else { + size + }, + content: ExpressionContent::Variable(parameter.name), + }), + terms.iter().map(|(_factor, index_pair)| index_pair).cloned().collect(), + ), + }), + Box::new(Expression { + size, + content: ExpressionContent::Constant( + DataType::SimdVector(size), + terms.iter().map(|(factor, _index_pair)| *factor).collect::>(), + ), + }), + ), + } + }); + body.push(( + DataType::SimdVector(size), + *simplify_and_legalize(Box::new(Expression { + size, + content: match name { + "Add" => ExpressionContent::Add(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())), + "Sub" => ExpressionContent::Subtract(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())), + _ => unreachable!(), + }, + })), + )); + } + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: DataType::MultiVector(result_class), + }, + parameters: vec![parameter_a.clone(), parameter_b.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body), + }), + }], + } + } else { + AstNode::None + } + } + + pub fn product<'a>( + name: &'static str, + product: &Product, + parameter_a: &Parameter<'a>, + parameter_b: &Parameter<'a>, + registry: &'a MultiVectorClassRegistry, + ) -> AstNode<'a> { + let a_flat_basis = parameter_a.multi_vector_class().flat_basis(); + let b_flat_basis = parameter_b.multi_vector_class().flat_basis(); + let mut result_signature = std::collections::HashSet::new(); + for product_term in product.terms.iter() { + if a_flat_basis.contains(&product_term.factor_a) && b_flat_basis.contains(&product_term.factor_b) { + result_signature.insert(product_term.product.unit.clone()); + } + } + let mut result_signature = result_signature.into_iter().collect::>(); + result_signature.sort(); + if let Some(result_class) = registry.get(&result_signature) { + let result_flat_basis = result_class.flat_basis(); + let mut sorted_terms = vec![vec![(0, 0); a_flat_basis.len()]; result_flat_basis.len()]; + for product_term in product.terms.iter() { + if let Some(y) = result_flat_basis.iter().position(|e| e == &product_term.product.unit) { + if let Some(x) = a_flat_basis.iter().position(|e| e == &product_term.factor_a) { + if let Some(gather_index) = b_flat_basis.iter().position(|e| e == &product_term.factor_b) { + sorted_terms[y][x] = (product_term.product.scalar, gather_index); + } + } + } + } + let mut body = Vec::new(); + let mut base_index = 0; + for result_group in result_class.grouped_basis.iter() { + let size = result_group.len(); + let mut expression = Expression { + size, + content: ExpressionContent::None, + }; + let result_terms = (0..size) + .map(|index_in_group| &sorted_terms[base_index + index_in_group]) + .collect::>(); + let transposed_terms = (0..result_terms[0].len()).map(|i| result_terms.iter().map(|inner| inner[i]).collect::>()); + let mut contraction = ( + Expression { + size, + content: ExpressionContent::None, + }, + Expression { + size, + content: ExpressionContent::None, + }, + vec![(0, 0); expression.size], + vec![(0, 0); expression.size], + vec![false; expression.size], + ); + for (index_in_a, a_terms) in transposed_terms.enumerate() { + if a_terms.iter().all(|(factor, _)| *factor == 0) { + continue; + } + let (a_group_index, a_index_in_group) = parameter_a.multi_vector_class().index_in_group(index_in_a); + let a_indices = a_terms.iter().map(|_| (a_group_index, a_index_in_group)).collect::>(); + let b_indices = a_terms + .iter() + .map(|(_, index_in_b)| parameter_b.multi_vector_class().index_in_group(*index_in_b)) + .collect::>(); + let non_zero_index = a_terms.iter().position(|(factor, _index_pair)| *factor != 0).unwrap(); + let b_group_index = b_indices[non_zero_index].0; + let b_indices = a_terms + .iter() + .enumerate() + .map(|(index, (factor, _index_pair))| b_indices[if *factor == 0 { non_zero_index } else { index }]) + .collect::>(); + let is_contractable = a_terms + .iter() + .enumerate() + .all(|(i, (factor, _))| *factor == 0 || *factor == 1 && !contraction.4[i]) + && (contraction.0.content == ExpressionContent::None + || contraction.0.size == parameter_a.multi_vector_class().grouped_basis[a_group_index].len()) + && (contraction.1.content == ExpressionContent::None + || contraction.1.size == parameter_b.multi_vector_class().grouped_basis[b_group_index].len()); + if is_contractable && a_terms.iter().any(|(factor, _)| *factor == 0) { + if contraction.0.content == ExpressionContent::None { + assert!(contraction.1.content == ExpressionContent::None); + contraction.0 = Expression { + size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(), + content: ExpressionContent::Variable(parameter_a.name), + }; + contraction.1 = Expression { + size: parameter_b.multi_vector_class().grouped_basis[b_group_index].len(), + content: ExpressionContent::Variable(parameter_b.name), + }; + contraction.2 = a_indices.iter().map(|(a_group_index, _)| (*a_group_index, 0)).collect(); + contraction.3 = b_indices.iter().map(|(b_group_index, _)| (*b_group_index, 0)).collect(); + } + for (i, (factor, _index_in_b)) in a_terms.iter().enumerate() { + if *factor == 1 { + contraction.2[i] = a_indices[i]; + contraction.3[i] = b_indices[i]; + contraction.4[i] = true; + } + } + } else { + expression = Expression { + size, + content: ExpressionContent::Add( + Box::new(expression), + Box::new(Expression { + size, + content: ExpressionContent::Multiply( + Box::new(Expression { + size, + content: ExpressionContent::Gather( + Box::new(Expression { + size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(), + content: ExpressionContent::Variable(parameter_a.name), + }), + a_indices, + ), + }), + Box::new(Expression { + size, + content: ExpressionContent::Multiply( + Box::new(Expression { + size, + content: ExpressionContent::Gather( + Box::new(Expression { + size: parameter_b.multi_vector_class().grouped_basis[b_group_index].len(), + content: ExpressionContent::Variable(parameter_b.name), + }), + b_indices, + ), + }), + Box::new(Expression { + size, + content: ExpressionContent::Constant( + DataType::SimdVector(size), + a_terms.iter().map(|(factor, _)| *factor).collect::>(), + ), + }), + ), + }), + ), + }), + ), + }; + } + } + if contraction.4.iter().any(|mask| *mask) { + expression = Expression { + size, + content: ExpressionContent::Add( + Box::new(expression), + Box::new(Expression { + size, + content: ExpressionContent::Multiply( + Box::new(Expression { + size, + content: ExpressionContent::Multiply( + Box::new(Expression { + size, + content: ExpressionContent::Gather(Box::new(contraction.0), contraction.2), + }), + Box::new(Expression { + size, + content: ExpressionContent::Gather(Box::new(contraction.1), contraction.3), + }), + ), + }), + Box::new(Expression { + size, + content: ExpressionContent::Constant( + DataType::SimdVector(size), + contraction.4.iter().map(|value| *value as isize).collect(), + ), + }), + ), + }), + ), + }; + } + if expression.content == ExpressionContent::None { + expression = Expression { + size, + content: ExpressionContent::Constant(DataType::SimdVector(size), (0..size).map(|_| 0).collect()), + }; + } + body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression)))); + base_index += size; + } + if body.is_empty() { + AstNode::None + } else { + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: DataType::MultiVector(result_class), + }, + parameters: vec![parameter_a.clone(), parameter_b.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body), + }), + }], + } + } + } else { + AstNode::None + } + } + + pub fn derive_multiplication<'a>( + name: &'static str, + geometric_product: &AstNode<'a>, + parameter_a: &Parameter<'a>, + parameter_b: &Parameter<'a>, + ) -> AstNode<'a> { + let geometric_product_result = result_of_trait!(geometric_product); + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: geometric_product_result.data_type.clone(), + }, + parameters: vec![parameter_a.clone(), parameter_b.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(parameter_b.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::Variable(parameter_b.name), + }, + )], + ), + }), + }], + } + } + + pub fn derive_squared_magnitude<'a>( + name: &'static str, + scalar_product: &AstNode<'a>, + involution: &AstNode<'a>, + parameter_a: &Parameter<'a>, + ) -> AstNode<'a> { + let scalar_product_result = result_of_trait!(scalar_product); + let involution_result = result_of_trait!(involution); + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: scalar_product_result.data_type.clone(), + }, + parameters: vec![parameter_a.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + scalar_product_result.name, + vec![( + DataType::MultiVector(involution_result.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + involution_result.name, + vec![], + ), + }, + )], + ), + }), + }], + } + } + + pub fn derive_magnitude<'a>(name: &'static str, squared_magnitude: &AstNode<'a>, parameter_a: &Parameter<'a>) -> AstNode<'a> { + let squared_magnitude_result = result_of_trait!(squared_magnitude); + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: squared_magnitude_result.data_type.clone(), + }, + parameters: vec![parameter_a.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod( + squared_magnitude_result.multi_vector_class(), + "Constructor", + vec![( + DataType::SimdVector(1), + Expression { + size: 1, + content: ExpressionContent::SquareRoot(Box::new(Expression { + size: 1, + content: ExpressionContent::Access( + Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + squared_magnitude_result.name, + vec![], + ), + }), + 0, + ), + })), + }, + )], + ), + }), + }], + } + } + + pub fn derive_signum<'a>( + name: &'static str, + geometric_product: &AstNode<'a>, + magnitude: &AstNode<'a>, + parameter_a: &Parameter<'a>, + ) -> AstNode<'a> { + let geometric_product_result = result_of_trait!(geometric_product); + let magnitude_result = result_of_trait!(magnitude); + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: geometric_product_result.data_type.clone(), + }, + parameters: vec![parameter_a.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(magnitude_result.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod( + magnitude_result.multi_vector_class(), + "Constructor", + vec![( + DataType::SimdVector(1), + Expression { + size: 1, + content: ExpressionContent::Divide( + Box::new(Expression { + size: 1, + content: ExpressionContent::Constant(DataType::SimdVector(1), vec![1]), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Access( + Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + magnitude_result.name, + vec![], + ), + }), + 0, + ), + }), + ), + }, + )], + ), + }, + )], + ), + }), + }], + } + } + + pub fn derive_inverse<'a>( + name: &'static str, + geometric_product: &AstNode<'a>, + squared_magnitude: &AstNode<'a>, + involution: &AstNode<'a>, + parameter_a: &Parameter<'a>, + ) -> AstNode<'a> { + let geometric_product_result = result_of_trait!(geometric_product); + let squared_magnitude_result = result_of_trait!(squared_magnitude); + let involution_result = result_of_trait!(involution); + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: geometric_product_result.data_type.clone(), + }, + parameters: vec![parameter_a.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + involution_result.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + involution_result.name, + vec![], + ), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(squared_magnitude_result.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod( + squared_magnitude_result.multi_vector_class(), + "Constructor", + vec![( + DataType::SimdVector(1), + Expression { + size: 1, + content: ExpressionContent::Divide( + Box::new(Expression { + size: 1, + content: ExpressionContent::Constant(DataType::SimdVector(1), vec![1]), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Access( + Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + squared_magnitude_result.name, + vec![], + ), + }), + 0, + ), + }), + ), + }, + )], + ), + }, + )], + ), + }), + }], + } + } + + pub fn derive_power_of_integer<'a>( + name: &'static str, + geometric_product: &AstNode<'a>, + constant_one: &AstNode<'a>, + inverse: &AstNode<'a>, + parameter_a: &Parameter<'a>, + parameter_b: &Parameter<'a>, + ) -> AstNode<'a> { + let geometric_product_result = result_of_trait!(geometric_product); + let constant_one_result = result_of_trait!(constant_one); + let inverse_result = result_of_trait!(inverse); + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: parameter_a.data_type.clone(), + }, + parameters: vec![parameter_a.clone(), parameter_b.clone()], + body: vec![ + AstNode::IfThenBlock { + condition: Box::new(Expression { + size: 1, + content: ExpressionContent::Equal( + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_b.name), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Constant(DataType::Integer, vec![0]), + }), + ), + }), + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod(parameter_a.multi_vector_class(), constant_one_result.name, vec![]), + }), + }], + }, + AstNode::VariableAssignment { + name: "x", + data_type: Some(parameter_a.data_type.clone()), + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::Select( + Box::new(Expression { + size: 1, + content: ExpressionContent::LessThan( + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_b.name), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Constant(DataType::Integer, vec![0]), + }), + ), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + inverse_result.name, + vec![], + ), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + ), + }), + }, + AstNode::VariableAssignment { + name: "y", + data_type: Some(parameter_a.data_type.clone()), + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod(parameter_a.multi_vector_class(), constant_one_result.name, vec![]), + }), + }, + AstNode::VariableAssignment { + name: "n", + data_type: Some(DataType::Integer), + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + DataType::Integer, + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_b.name), + }), + "Abs", + vec![], + ), + }), + }, + AstNode::WhileLoopBlock { + condition: Box::new(Expression { + size: 1, + content: ExpressionContent::LessThan( + Box::new(Expression { + size: 1, + content: ExpressionContent::Constant(DataType::Integer, vec![1]), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable("n"), + }), + ), + }), + body: vec![ + AstNode::IfThenBlock { + condition: Box::new(Expression { + size: 1, + content: ExpressionContent::Equal( + Box::new(Expression { + size: 1, + content: ExpressionContent::LogicAnd( + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable("n"), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Constant(DataType::Integer, vec![1]), + }), + ), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Constant(DataType::Integer, vec![1]), + }), + ), + }), + body: vec![AstNode::VariableAssignment { + name: "y", + data_type: None, + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable("x"), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(parameter_a.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::Variable("y"), + }, + )], + ), + }), + }], + }, + AstNode::VariableAssignment { + name: "x", + data_type: None, + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable("x"), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(parameter_a.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::Variable("x"), + }, + )], + ), + }), + }, + AstNode::VariableAssignment { + name: "n", + data_type: None, + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::BitShiftRight( + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable("n"), + }), + Box::new(Expression { + size: 1, + content: ExpressionContent::Constant(DataType::Integer, vec![1]), + }), + ), + }), + }, + ], + }, + AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable("x"), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(parameter_a.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::Variable("y"), + }, + )], + ), + }), + }, + ], + } + } + + pub fn derive_division<'a>( + name: &'static str, + geometric_product: &AstNode<'a>, + inverse: &AstNode<'a>, + parameter_a: &Parameter<'a>, + parameter_b: &Parameter<'a>, + ) -> AstNode<'a> { + let geometric_product_result = result_of_trait!(geometric_product); + let inverse_result = result_of_trait!(inverse); + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: geometric_product_result.data_type.clone(), + }, + parameters: vec![parameter_a.clone(), parameter_b.clone()], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(inverse_result.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_b.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_b.name), + }), + inverse_result.name, + vec![], + ), + }, + )], + ), + }), + }], + } + } + + pub fn derive_sandwich_product<'a>( + name: &'static str, + geometric_product: &AstNode<'a>, + geometric_product_2: &AstNode<'a>, + involution: &AstNode<'a>, + conversion: Option<&AstNode<'a>>, + parameter_a: &Parameter<'a>, + parameter_b: &Parameter<'a>, + ) -> AstNode<'a> { + let geometric_product_result = result_of_trait!(geometric_product); + let geometric_product_2_result = result_of_trait!(geometric_product_2); + let involution_result = result_of_trait!(involution); + let product = Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + geometric_product_result.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(parameter_b.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::Variable(parameter_b.name), + }, + )], + ), + }), + geometric_product_2_result.name, + vec![( + DataType::MultiVector(involution_result.multi_vector_class()), + Expression { + size: 1, + content: match name { + "Reflection" => ExpressionContent::Variable(parameter_a.name), + "Transformation" => ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + involution_result.name, + vec![], + ), + _ => unreachable!(), + }, + }, + )], + ), + }); + let conversion_result = if let Some(conversion) = conversion { + result_of_trait!(conversion) + } else { + geometric_product_2_result + }; + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: conversion_result.data_type.clone(), + }, + parameters: vec![parameter_a.clone(), parameter_b.clone()], + body: vec![AstNode::ReturnStatement { + expression: if conversion.is_some() { + Box::new(Expression { + size: 1, + content: ExpressionContent::Conversion( + geometric_product_2_result.multi_vector_class(), + conversion_result.multi_vector_class(), + product, + ), + }) + } else { + product + }, + }], + } + } +} diff --git a/codegen/src/emit.rs b/codegen/src/emit.rs new file mode 100644 index 0000000..e80f202 --- /dev/null +++ b/codegen/src/emit.rs @@ -0,0 +1,46 @@ +pub fn camel_to_snake_case(collector: &mut W, name: &str) -> std::io::Result<()> { + let mut underscores = name.chars().enumerate().filter(|(_i, c)| c.is_uppercase()).map(|(i, _c)| i).peekable(); + for (i, c) in name.to_lowercase().bytes().enumerate() { + if let Some(next_underscores) = underscores.peek() { + if i == *next_underscores { + if i > 0 { + collector.write_all(b"_")?; + } + underscores.next(); + } + } + collector.write_all(&[c])?; + } + Ok(()) +} + +pub fn emit_indentation(collector: &mut W, indentation: usize) -> std::io::Result<()> { + for _ in 0..indentation { + collector.write_all(b" ")?; + } + Ok(()) +} + +use crate::{ast::AstNode, glsl, rust}; + +pub struct Emitter { + pub rust_collector: W, + pub glsl_collector: W, +} + +impl Emitter { + pub fn new(path: &std::path::Path) -> Self { + Self { + rust_collector: std::fs::File::create(path.with_extension("rs")).unwrap(), + glsl_collector: std::fs::File::create(path.with_extension("glsl")).unwrap(), + } + } +} + +impl Emitter { + pub fn emit(&mut self, ast_node: &AstNode) -> std::io::Result<()> { + rust::emit_code(&mut self.rust_collector, ast_node, 0)?; + glsl::emit_code(&mut self.glsl_collector, ast_node, 0)?; + Ok(()) + } +} diff --git a/codegen/src/glsl.rs b/codegen/src/glsl.rs new file mode 100644 index 0000000..d66b066 --- /dev/null +++ b/codegen/src/glsl.rs @@ -0,0 +1,254 @@ +use crate::{ + ast::{AstNode, DataType, Expression, ExpressionContent}, + emit::{camel_to_snake_case, emit_indentation}, +}; + +const COMPONENT: &[&str] = &["x", "y", "z", "w"]; + +fn emit_data_type(collector: &mut W, data_type: &DataType) -> std::io::Result<()> { + match data_type { + DataType::Integer => collector.write_all(b"int"), + DataType::SimdVector(size) => collector.write_fmt(format_args!("vec{}", size)), + DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)), + } +} + +fn emit_expression(collector: &mut W, expression: &Expression) -> std::io::Result<()> { + match &expression.content { + ExpressionContent::None => unreachable!(), + ExpressionContent::Variable(name) => { + collector.write_all(name.bytes().collect::>().as_slice())?; + } + ExpressionContent::InvokeClassMethod(_, _, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, _, arguments) => { + match &expression.content { + ExpressionContent::InvokeInstanceMethod(result_class, inner_expression, method_name, _) => { + if let DataType::MultiVector(result_class) = result_class { + camel_to_snake_case(collector, &result_class.class_name)?; + collector.write_all(b"_")?; + } + for (argument_class, _argument) in arguments.iter() { + if let DataType::MultiVector(argument_class) = argument_class { + camel_to_snake_case(collector, &argument_class.class_name)?; + collector.write_all(b"_")?; + } + } + camel_to_snake_case(collector, method_name)?; + collector.write_all(b"(")?; + emit_expression(collector, &inner_expression)?; + if !arguments.is_empty() { + collector.write_all(b", ")?; + } + } + ExpressionContent::InvokeClassMethod(class, method_name, _) => { + if *method_name == "Constructor" { + collector.write_fmt(format_args!("{}", &class.class_name))?; + } else { + camel_to_snake_case(collector, &class.class_name)?; + collector.write_all(b"_")?; + camel_to_snake_case(collector, method_name)?; + } + collector.write_all(b"(")?; + } + _ => unreachable!(), + } + for (i, (_argument_class, argument)) in arguments.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_expression(collector, &argument)?; + } + collector.write_all(b")")?; + } + ExpressionContent::Conversion(source_class, destination_class, inner_expression) => { + camel_to_snake_case(collector, &source_class.class_name)?; + collector.write_all(b"_")?; + camel_to_snake_case(collector, &destination_class.class_name)?; + collector.write_all(b"_into(")?; + emit_expression(collector, &inner_expression)?; + collector.write_all(b")")?; + } + ExpressionContent::Select(condition_expression, then_expression, else_expression) => { + collector.write_all(b"(")?; + emit_expression(collector, &condition_expression)?; + collector.write_all(b") ? ")?; + emit_expression(collector, &then_expression)?; + collector.write_all(b" : ")?; + emit_expression(collector, &else_expression)?; + } + ExpressionContent::Access(inner_expression, array_index) => { + emit_expression(collector, &inner_expression)?; + collector.write_fmt(format_args!(".g{}", array_index))?; + } + ExpressionContent::Swizzle(inner_expression, indices) => { + emit_expression(collector, &inner_expression)?; + collector.write_all(b".")?; + for component_index in indices.iter() { + collector.write_all(COMPONENT[*component_index].bytes().collect::>().as_slice())?; + } + } + ExpressionContent::Gather(inner_expression, indices) => { + if expression.size > 1 { + collector.write_fmt(format_args!("vec{}(", expression.size))?; + } + for (i, (array_index, component_index)) in indices.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_expression(collector, &inner_expression)?; + collector.write_fmt(format_args!(".g{}", array_index))?; + if inner_expression.size > 1 { + collector.write_fmt(format_args!(".{}", COMPONENT[*component_index]))?; + } + } + if expression.size > 1 { + collector.write_all(b")")?; + } + } + ExpressionContent::Constant(data_type, values) => match data_type { + DataType::Integer => collector.write_fmt(format_args!("{}", values[0] as f32))?, + DataType::SimdVector(_size) => { + if expression.size == 1 { + collector.write_fmt(format_args!("{:.1}", values[0] as f32))? + } else { + collector.write_fmt(format_args!( + "vec{}({})", + expression.size, + values.iter().map(|value| format!("{:.1}", *value as f32)).collect::>().join(", ") + ))? + } + } + _ => unreachable!(), + }, + ExpressionContent::SquareRoot(inner_expression) => { + collector.write_all(b"sqrt(")?; + emit_expression(collector, &inner_expression)?; + collector.write_all(b")")?; + } + ExpressionContent::Add(lhs, rhs) + | ExpressionContent::Subtract(lhs, rhs) + | ExpressionContent::Multiply(lhs, rhs) + | ExpressionContent::Divide(lhs, rhs) + | ExpressionContent::LessThan(lhs, rhs) + | ExpressionContent::Equal(lhs, rhs) + | ExpressionContent::LogicAnd(lhs, rhs) + | ExpressionContent::BitShiftRight(lhs, rhs) => { + if let ExpressionContent::LogicAnd(_, _) = expression.content { + collector.write_all(b"(")?; + } + emit_expression(collector, &lhs)?; + collector.write_all(match expression.content { + ExpressionContent::Add(_, _) => b" + ", + ExpressionContent::Subtract(_, _) => b" - ", + ExpressionContent::Multiply(_, _) => b" * ", + ExpressionContent::Divide(_, _) => b" / ", + ExpressionContent::LessThan(_, _) => b" < ", + ExpressionContent::Equal(_, _) => b" == ", + ExpressionContent::LogicAnd(_, _) => b" & ", + ExpressionContent::BitShiftRight(_, _) => b" >> ", + _ => unreachable!(), + })?; + emit_expression(collector, &rhs)?; + if let ExpressionContent::LogicAnd(_, _) = expression.content { + collector.write_all(b")")?; + } + } + } + Ok(()) +} + +pub fn emit_code(collector: &mut W, ast_node: &AstNode, indentation: usize) -> std::io::Result<()> { + match ast_node { + AstNode::None => {} + AstNode::Preamble => {} + AstNode::ClassDefinition { class } => { + collector.write_fmt(format_args!("struct {} {{\n", class.class_name))?; + for (i, group) in class.grouped_basis.iter().enumerate() { + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"// ")?; + for (i, element) in group.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("{}", element))?; + } + collector.write_all(b"\n")?; + emit_indentation(collector, indentation + 1)?; + if group.len() == 1 { + collector.write_all(b"float")?; + } else { + collector.write_fmt(format_args!("vec{}", group.len()))?; + } + collector.write_fmt(format_args!(" g{};\n", i))?; + } + emit_indentation(collector, indentation)?; + collector.write_all(b"};\n\n")?; + } + AstNode::ReturnStatement { expression } => { + collector.write_all(b"return ")?; + emit_expression(collector, expression)?; + collector.write_all(b";\n")?; + } + AstNode::VariableAssignment { name, data_type, expression } => { + if let Some(data_type) = data_type { + emit_data_type(collector, data_type)?; + collector.write_all(b" ")?; + } + collector.write_fmt(format_args!("{} = ", name))?; + emit_expression(collector, expression)?; + collector.write_all(b";\n")?; + } + AstNode::IfThenBlock { condition, body } | AstNode::WhileLoopBlock { condition, body } => { + collector.write_all(match &ast_node { + AstNode::IfThenBlock { .. } => b"if", + AstNode::WhileLoopBlock { .. } => b"while", + _ => unreachable!(), + })?; + collector.write_all(b"(")?; + emit_expression(collector, condition)?; + collector.write_all(b") {\n")?; + for statement in body.iter() { + emit_indentation(collector, indentation + 1)?; + emit_code(collector, statement, indentation + 1)?; + } + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n")?; + } + AstNode::TraitImplementation { result, parameters, body } => { + collector.write_fmt(format_args!("{} ", result.multi_vector_class().class_name))?; + match parameters.len() { + 0 => camel_to_snake_case(collector, &result.multi_vector_class().class_name)?, + 1 if result.name == "Into" => { + camel_to_snake_case(collector, ¶meters[0].multi_vector_class().class_name)?; + collector.write_all(b"_")?; + camel_to_snake_case(collector, &result.multi_vector_class().class_name)?; + } + 1 => camel_to_snake_case(collector, ¶meters[0].multi_vector_class().class_name)?, + 2 if result.name == "Powi" => camel_to_snake_case(collector, ¶meters[0].multi_vector_class().class_name)?, + 2 => { + camel_to_snake_case(collector, ¶meters[0].multi_vector_class().class_name)?; + collector.write_all(b"_")?; + camel_to_snake_case(collector, ¶meters[1].multi_vector_class().class_name)?; + } + _ => unreachable!(), + } + collector.write_all(b"_")?; + camel_to_snake_case(collector, result.name)?; + collector.write_all(b"(")?; + for (i, parameter) in parameters.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_data_type(collector, ¶meter.data_type)?; + collector.write_fmt(format_args!(" {}", parameter.name))?; + } + collector.write_all(b") {\n")?; + for statement in body.iter() { + emit_indentation(collector, indentation + 1)?; + emit_code(collector, statement, indentation + 1)?; + } + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n\n")?; + } + } + Ok(()) +} diff --git a/codegen/src/main.rs b/codegen/src/main.rs new file mode 100644 index 0000000..c584e4b --- /dev/null +++ b/codegen/src/main.rs @@ -0,0 +1,224 @@ +mod algebra; +mod ast; +mod compile; +mod emit; +mod glsl; +mod rust; + +use crate::{ + algebra::{BasisElement, GeometricAlgebra, Involution, MultiVectorClass, MultiVectorClassRegistry, Product, ScaledElement}, + ast::{AstNode, DataType, Parameter}, + emit::Emitter, +}; + +fn main() { + let mut args = std::env::args(); + let _executable = args.next().unwrap(); + let config = args.next().unwrap(); + let mut config_iter = config.split(';'); + let algebra_descriptor = config_iter.next().unwrap(); + let mut algebra_descriptor_iter = algebra_descriptor.split(':'); + let algebra_name = algebra_descriptor_iter.next().unwrap(); + let generator_squares = algebra_descriptor_iter + .next() + .unwrap() + .split(',') + .map(|x| x.parse::().unwrap()) + .collect::>(); + let algebra = GeometricAlgebra { + generator_squares: generator_squares.as_slice(), + }; + let involutions = Involution::involutions(&algebra); + let products = Product::products(&algebra); + let basis = algebra.sorted_basis(); + for a in basis.iter() { + for b in basis.iter() { + print!( + "{:1$} ", + ScaledElement::product(&ScaledElement::from(&a), &ScaledElement::from(&b), &algebra), + generator_squares.len() + 2 + ); + } + println!(); + } + let mut registry = MultiVectorClassRegistry::default(); + for multi_vector_descriptor in config_iter { + let mut multi_vector_descriptor_iter = multi_vector_descriptor.split(':'); + registry.register(MultiVectorClass { + class_name: multi_vector_descriptor_iter.next().unwrap().to_owned(), + grouped_basis: multi_vector_descriptor_iter + .next() + .unwrap() + .split('|') + .map(|group_descriptor| { + group_descriptor + .split(',') + .map(|element_name| BasisElement::new(element_name)) + .collect::>() + }) + .collect::>(), + }); + } + let mut emitter = Emitter::new(&std::path::Path::new("../src/").join(std::path::Path::new(algebra_name))); + emitter.emit(&AstNode::Preamble).unwrap(); + for class in registry.classes.iter() { + emitter.emit(&AstNode::ClassDefinition { class }).unwrap(); + } + let involution_name = if generator_squares == vec![-1] { "Conjugate" } else { "Transpose" }; + let mut trait_implementations = std::collections::HashMap::new(); + for class_a in registry.classes.iter() { + let parameter_a = Parameter { + name: "self", + data_type: DataType::MultiVector(class_a), + }; + let mut single_trait_implementations = std::collections::HashMap::new(); + for name in &["Zero", "One"] { + let ast_node = class_a.constant(name); + emitter.emit(&ast_node).unwrap(); + if ast_node != AstNode::None { + single_trait_implementations.insert(name.to_string(), ast_node); + } + } + for (name, involution) in involutions.iter() { + let ast_node = MultiVectorClass::involution(name, &involution, ¶meter_a, ®istry); + emitter.emit(&ast_node).unwrap(); + if ast_node != AstNode::None { + single_trait_implementations.insert(name.to_string(), ast_node); + } + } + let mut pair_trait_implementations = std::collections::HashMap::new(); + for class_b in registry.classes.iter() { + let mut trait_implementations = std::collections::HashMap::new(); + let parameter_b = Parameter { + name: "other", + data_type: DataType::MultiVector(class_b), + }; + let name = "Into"; + let ast_node = MultiVectorClass::conversion(name, ¶meter_a, ¶meter_b.multi_vector_class()); + emitter.emit(&ast_node).unwrap(); + if ast_node != AstNode::None { + trait_implementations.insert(name.to_string(), ast_node); + } + for name in &["Add", "Sub"] { + let ast_node = MultiVectorClass::sum(*name, ¶meter_a, ¶meter_b, ®istry); + emitter.emit(&ast_node).unwrap(); + if ast_node != AstNode::None { + trait_implementations.insert(name.to_string(), ast_node); + } + } + for (name, product) in products.iter() { + let ast_node = MultiVectorClass::product(name, &product, ¶meter_a, ¶meter_b, ®istry); + emitter.emit(&ast_node).unwrap(); + if ast_node != AstNode::None { + trait_implementations.insert(name.to_string(), ast_node); + } + } + pair_trait_implementations.insert( + parameter_b.multi_vector_class().class_name.clone(), + (parameter_b.clone(), trait_implementations), + ); + } + for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() { + if let Some(scalar_product) = pair_trait_implementations.get("ScalarProduct") { + if let Some(involution) = single_trait_implementations.get(involution_name) { + if parameter_a.multi_vector_class() == parameter_b.multi_vector_class() { + let squared_magnitude = + MultiVectorClass::derive_squared_magnitude("SquaredMagnitude", &scalar_product, &involution, ¶meter_a); + emitter.emit(&squared_magnitude).unwrap(); + let magnitude = MultiVectorClass::derive_magnitude("Magnitude", &squared_magnitude, ¶meter_a); + emitter.emit(&magnitude).unwrap(); + single_trait_implementations.insert(result_of_trait!(squared_magnitude).name.to_string(), squared_magnitude); + single_trait_implementations.insert(result_of_trait!(magnitude).name.to_string(), magnitude); + } + } + } + } + for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() { + if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") { + if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement { index: 0 }]] { + if let Some(magnitude) = single_trait_implementations.get("Magnitude") { + let signum = MultiVectorClass::derive_signum("Signum", &geometric_product, &magnitude, ¶meter_a); + emitter.emit(&signum).unwrap(); + single_trait_implementations.insert(result_of_trait!(signum).name.to_string(), signum); + } + if let Some(squared_magnitude) = single_trait_implementations.get("SquaredMagnitude") { + if let Some(involution) = single_trait_implementations.get(involution_name) { + let inverse = + MultiVectorClass::derive_inverse("Inverse", &geometric_product, &squared_magnitude, &involution, ¶meter_a); + emitter.emit(&inverse).unwrap(); + single_trait_implementations.insert(result_of_trait!(inverse).name.to_string(), inverse); + } + } + } + } + } + trait_implementations.insert( + parameter_a.multi_vector_class().class_name.clone(), + (parameter_a.clone(), single_trait_implementations, pair_trait_implementations), + ); + } + for (parameter_a, single_trait_implementations, pair_trait_implementations) in trait_implementations.values() { + for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() { + if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") { + let geometric_product_result = result_of_trait!(geometric_product); + let multiplication = MultiVectorClass::derive_multiplication("Mul", &geometric_product, ¶meter_a, ¶meter_b); + emitter.emit(&multiplication).unwrap(); + if parameter_a.multi_vector_class() == parameter_b.multi_vector_class() + && geometric_product_result.multi_vector_class() == parameter_a.multi_vector_class() + { + if let Some(constant_one) = single_trait_implementations.get("One") { + if let Some(inverse) = single_trait_implementations.get("Inverse") { + let power_of_integer = MultiVectorClass::derive_power_of_integer( + "Powi", + &geometric_product, + &constant_one, + &inverse, + ¶meter_a, + &Parameter { + name: "exponent", + data_type: DataType::Integer, + }, + ); + emitter.emit(&power_of_integer).unwrap(); + } + } + } + if let Some(b_trait_implementations) = trait_implementations.get(¶meter_b.multi_vector_class().class_name) { + if let Some(inverse) = b_trait_implementations.1.get("Inverse") { + let division = MultiVectorClass::derive_division("Div", &geometric_product, &inverse, ¶meter_a, ¶meter_b); + emitter.emit(&division).unwrap(); + } + } + if let Some(involution) = single_trait_implementations.get(involution_name) { + if let Some(b_trait_implementations) = trait_implementations.get(&geometric_product_result.multi_vector_class().class_name) { + if let Some(b_pair_trait_implementations) = b_trait_implementations.2.get(¶meter_a.multi_vector_class().class_name) { + if let Some(geometric_product_2) = b_pair_trait_implementations.1.get("GeometricProduct") { + let geometric_product_2_result = result_of_trait!(geometric_product_2); + if let Some(c_trait_implementations) = + trait_implementations.get(&geometric_product_2_result.multi_vector_class().class_name) + { + if let Some(c_pair_trait_implementations) = + c_trait_implementations.2.get(¶meter_b.multi_vector_class().class_name) + { + for name in &["Reflection", "Transformation"] { + let transformation = MultiVectorClass::derive_sandwich_product( + name, + &geometric_product, + &geometric_product_2, + &involution, + c_pair_trait_implementations.1.get("Into"), + ¶meter_a, + ¶meter_b, + ); + emitter.emit(&transformation).unwrap(); + } + } + } + } + } + } + } + } + } + } +} diff --git a/codegen/src/rust.rs b/codegen/src/rust.rs new file mode 100644 index 0000000..e97f8ed --- /dev/null +++ b/codegen/src/rust.rs @@ -0,0 +1,289 @@ +use crate::{ + ast::{AstNode, DataType, Expression, ExpressionContent}, + emit::{camel_to_snake_case, emit_indentation}, +}; + +fn emit_data_type(collector: &mut W, data_type: &DataType) -> std::io::Result<()> { + match data_type { + DataType::Integer => collector.write_all(b"isize"), + DataType::SimdVector(size) => collector.write_fmt(format_args!("Simd32x{}", size)), + DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)), + } +} + +fn emit_expression(collector: &mut W, expression: &Expression) -> std::io::Result<()> { + match &expression.content { + ExpressionContent::None => unreachable!(), + ExpressionContent::Variable(name) => { + collector.write_all(name.bytes().collect::>().as_slice())?; + } + ExpressionContent::InvokeClassMethod(_, method_name, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, method_name, arguments) => { + match &expression.content { + ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, _, _) => { + emit_expression(collector, &inner_expression)?; + collector.write_all(b".")?; + } + ExpressionContent::InvokeClassMethod(class, _, _) => { + if *method_name == "Constructor" { + collector.write_fmt(format_args!("{} {{ ", class.class_name))?; + } else { + collector.write_fmt(format_args!("{}::", class.class_name))?; + } + } + _ => unreachable!(), + } + if *method_name != "Constructor" { + camel_to_snake_case(collector, method_name)?; + collector.write_all(b"(")?; + } + for (i, (_argument_class, argument)) in arguments.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + if *method_name == "Constructor" { + collector.write_fmt(format_args!("g{}: ", i))?; + } + emit_expression(collector, &argument)?; + } + if *method_name == "Constructor" { + collector.write_all(b" }")?; + } else { + collector.write_all(b")")?; + } + } + ExpressionContent::Conversion(_source_class, _destination_class, inner_expression) => { + emit_expression(collector, &inner_expression)?; + collector.write_all(b".into()")?; + } + ExpressionContent::Select(condition_expression, then_expression, else_expression) => { + collector.write_all(b"if ")?; + emit_expression(collector, &condition_expression)?; + collector.write_all(b" { ")?; + emit_expression(collector, &then_expression)?; + collector.write_all(b" } else { ")?; + emit_expression(collector, &else_expression)?; + collector.write_all(b" }")?; + } + ExpressionContent::Access(inner_expression, array_index) => { + emit_expression(collector, &inner_expression)?; + collector.write_fmt(format_args!(".g{}", array_index))?; + } + ExpressionContent::Swizzle(inner_expression, indices) => { + if expression.size == 1 { + emit_expression(collector, &inner_expression)?; + if inner_expression.size > 1 { + collector.write_fmt(format_args!(".get_f({})", indices[0]))?; + } + } else { + collector.write_all(b"swizzle!(")?; + emit_expression(collector, &inner_expression)?; + collector.write_all(b", ")?; + for (i, component_index) in indices.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("{}", *component_index))?; + } + collector.write_all(b")")?; + } + } + ExpressionContent::Gather(inner_expression, indices) => { + if expression.size > 1 { + collector.write_fmt(format_args!("Simd32x{}::from(", expression.size))?; + } + if indices.len() > 1 { + collector.write_all(b"[")?; + } + for (i, (array_index, component_index)) in indices.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + emit_expression(collector, &inner_expression)?; + collector.write_fmt(format_args!(".g{}", array_index))?; + if inner_expression.size > 1 { + collector.write_fmt(format_args!(".get_f({})", *component_index))?; + } + } + if indices.len() > 1 { + collector.write_all(b"]")?; + } + if expression.size > 1 { + collector.write_all(b")")?; + } + } + ExpressionContent::Constant(data_type, values) => match data_type { + DataType::Integer => collector.write_fmt(format_args!("{}", values[0] as f32))?, + DataType::SimdVector(_size) => { + if expression.size == 1 { + collector.write_fmt(format_args!("{:.1}", values[0] as f32))?; + } else { + collector.write_fmt(format_args!("Simd32x{}::from(", expression.size))?; + if values.len() > 1 { + collector.write_all(b"[")?; + } + for (i, value) in values.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("{:.1}", *value as f32))?; + } + if values.len() > 1 { + collector.write_all(b"]")?; + } + collector.write_all(b")")?; + } + } + _ => unreachable!(), + }, + ExpressionContent::SquareRoot(inner_expression) => { + emit_expression(collector, &inner_expression)?; + collector.write_all(b".sqrt()")?; + } + ExpressionContent::Add(lhs, rhs) + | ExpressionContent::Subtract(lhs, rhs) + | ExpressionContent::Multiply(lhs, rhs) + | ExpressionContent::Divide(lhs, rhs) + | ExpressionContent::LessThan(lhs, rhs) + | ExpressionContent::Equal(lhs, rhs) + | ExpressionContent::LogicAnd(lhs, rhs) + | ExpressionContent::BitShiftRight(lhs, rhs) => { + emit_expression(collector, &lhs)?; + collector.write_all(match expression.content { + ExpressionContent::Add(_, _) => b" + ", + ExpressionContent::Subtract(_, _) => b" - ", + ExpressionContent::Multiply(_, _) => b" * ", + ExpressionContent::Divide(_, _) => b" / ", + ExpressionContent::LessThan(_, _) => b" < ", + ExpressionContent::Equal(_, _) => b" == ", + ExpressionContent::LogicAnd(_, _) => b" & ", + ExpressionContent::BitShiftRight(_, _) => b" >> ", + _ => unreachable!(), + })?; + emit_expression(collector, &rhs)?; + } + } + Ok(()) +} + +pub fn emit_code(collector: &mut W, ast_node: &AstNode, indentation: usize) -> std::io::Result<()> { + match &ast_node { + AstNode::None => {} + AstNode::Preamble => { + collector.write_all(b"#![allow(clippy::assign_op_pattern)]\n")?; + collector.write_all(b"use crate::*;\nuse std::ops::{Add, Neg, Sub, Mul, Div};\n\n")?; + } + AstNode::ClassDefinition { class } => { + collector.write_fmt(format_args!("#[derive(Clone, Copy)]\npub struct {} {{\n", class.class_name))?; + for (i, group) in class.grouped_basis.iter().enumerate() { + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"/// ")?; + for (i, element) in group.iter().enumerate() { + if i > 0 { + collector.write_all(b", ")?; + } + collector.write_fmt(format_args!("{}", element))?; + } + collector.write_all(b"\n")?; + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("pub g{}: ", i))?; + if group.len() == 1 { + collector.write_all(b"f32,\n")?; + } else { + collector.write_fmt(format_args!("Simd32x{},\n", group.len()))?; + } + } + collector.write_all(b"}\n\n")?; + } + AstNode::ReturnStatement { expression } => { + collector.write_all(b"return ")?; + emit_expression(collector, expression)?; + collector.write_all(b";\n")?; + } + AstNode::VariableAssignment { name, data_type, expression } => { + if let Some(data_type) = data_type { + collector.write_fmt(format_args!("let mut {}", name))?; + collector.write_all(b": ")?; + emit_data_type(collector, data_type)?; + } else { + collector.write_fmt(format_args!("{}", name))?; + } + collector.write_all(b" = ")?; + emit_expression(collector, expression)?; + collector.write_all(b";\n")?; + } + AstNode::IfThenBlock { condition, body } | AstNode::WhileLoopBlock { condition, body } => { + collector.write_all(match &ast_node { + AstNode::IfThenBlock { .. } => b"if ", + AstNode::WhileLoopBlock { .. } => b"while ", + _ => unreachable!(), + })?; + emit_expression(collector, condition)?; + collector.write_all(b" {\n")?; + for statement in body.iter() { + emit_indentation(collector, indentation + 1)?; + emit_code(collector, statement, indentation + 1)?; + } + emit_indentation(collector, indentation)?; + collector.write_all(b"}\n")?; + } + AstNode::TraitImplementation { result, parameters, body } => { + match parameters.len() { + 0 => collector.write_fmt(format_args!("impl {} for {}", result.name, result.multi_vector_class().class_name))?, + 1 if result.name == "Into" => collector.write_fmt(format_args!( + "impl {}<{}> for {}", + result.name, + result.multi_vector_class().class_name, + parameters[0].multi_vector_class().class_name, + ))?, + 1 => collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))?, + 2 if result.name == "Powi" => { + collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))? + } + 2 => collector.write_fmt(format_args!( + "impl {}<{}> for {}", + result.name, + parameters[1].multi_vector_class().class_name, + parameters[0].multi_vector_class().class_name, + ))?, + _ => unreachable!(), + } + collector.write_all(b" {\n")?; + if !parameters.is_empty() && result.name != "Into" { + emit_indentation(collector, indentation + 1)?; + collector.write_fmt(format_args!("type Output = {};\n\n", result.multi_vector_class().class_name))?; + } + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"fn ")?; + camel_to_snake_case(collector, result.name)?; + match parameters.len() { + 0 => collector.write_all(b"() -> Self")?, + 1 => { + collector.write_fmt(format_args!("({}) -> ", parameters[0].name))?; + emit_data_type(collector, &result.data_type)?; + } + 2 => { + collector.write_fmt(format_args!("({}, {}: ", parameters[0].name, parameters[1].name))?; + emit_data_type(collector, ¶meters[1].data_type)?; + collector.write_all(b") -> ")?; + emit_data_type(collector, &result.data_type)?; + } + _ => unreachable!(), + } + collector.write_all(b" {\n")?; + for (i, statement) in body.iter().enumerate() { + emit_indentation(collector, indentation + 2)?; + if i + 1 == body.len() { + if let AstNode::ReturnStatement { expression } = statement { + emit_expression(collector, expression)?; + collector.write_all(b"\n")?; + break; + } + } + emit_code(collector, statement, indentation + 2)?; + } + emit_indentation(collector, indentation + 1)?; + collector.write_all(b"}\n}\n\n")?; + } + } + Ok(()) +} diff --git a/simd/Cargo.toml b/simd/Cargo.toml new file mode 100644 index 0000000..db81fcf --- /dev/null +++ b/simd/Cargo.toml @@ -0,0 +1,5 @@ +[package] +name = "simd" +version = "0.1.0" +authors = ["Alexander Meißner "] +edition = "2018" diff --git a/simd/src/lib.rs b/simd/src/lib.rs new file mode 100755 index 0000000..8018e46 --- /dev/null +++ b/simd/src/lib.rs @@ -0,0 +1,378 @@ +#![cfg_attr(all(target_arch = "wasm32", target_feature = "simd128"), feature(wasm_simd))] +#![cfg_attr(all(any(target_arch = "arm", target_arch = "aarch64"), target_feature = "neon"), feature(stdsimd))] + +#[cfg(target_arch = "aarch64")] +pub use std::arch::aarch64::*; +#[cfg(target_arch = "arm")] +pub use std::arch::arm::*; +#[cfg(target_arch = "wasm32")] +pub use std::arch::wasm32::*; +#[cfg(target_arch = "x86")] +pub use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +pub use std::arch::x86_64::*; + +#[derive(Clone, Copy)] +#[repr(C)] +pub union Simd32x4 { + // Intel / AMD + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2"))] + pub f128: __m128, + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2"))] + pub i128: __m128i, + + // ARM + #[cfg(all(any(target_arch = "arm", target_arch = "aarch64"), target_feature = "neon"))] + pub f128: float32x4_t, + #[cfg(all(any(target_arch = "arm", target_arch = "aarch64"), target_feature = "neon"))] + pub i128: int32x4_t, + #[cfg(all(any(target_arch = "arm", target_arch = "aarch64"), target_feature = "neon"))] + pub u128: uint32x4_t, + + // Web + #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] + pub v128: v128, + + // Fallback + pub f32x4: [f32; 4], + pub i32x4: [i32; 4], + pub u32x4: [u32; 4], +} + +#[derive(Clone, Copy)] +#[repr(C)] +pub union Simd32x3 { + pub v32x4: Simd32x4, + + // Fallback + pub f32x3: [f32; 3], + pub i32x3: [i32; 3], + pub u32x3: [u32; 3], +} + +#[derive(Clone, Copy)] +#[repr(C)] +pub union Simd32x2 { + pub v32x4: Simd32x4, + + // Fallback + pub f32x2: [f32; 2], + pub i32x2: [i32; 2], + pub u32x2: [u32; 2], +} + +#[macro_export] +macro_rules! match_architecture { + ($Simd:ident, $native:tt, $fallback:tt,) => {{ + #[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "arm", target_arch = "aarch64", target_arch = "wasm32"))] + unsafe { $Simd $native } + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "arm", target_arch = "aarch64", target_arch = "wasm32")))] + unsafe { $Simd $fallback } + }}; + ($Simd:ident, $x86:tt, $arm:tt, $web:tt, $fallback:tt,) => {{ + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2"))] + unsafe { $Simd $x86 } + #[cfg(all(any(target_arch = "arm", target_arch = "aarch64"), target_feature = "neon"))] + unsafe { $Simd $arm } + #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] + unsafe { $Simd $web } + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "arm", target_arch = "aarch64", target_arch = "wasm32")))] + unsafe { $Simd $fallback } + }}; +} + +#[macro_export] +macro_rules! swizzle { + ($self:expr, $x:literal, $y:literal, $z:literal, $w:literal) => { + $crate::match_architecture!( + Simd32x4, + { f128: $crate::_mm_permute_ps($self.f128, ($x as i32) | (($y as i32) << 2) | (($z as i32) << 4) | (($w as i32) << 6)) }, + { f32x4: [ + $self.f32x4[$x], + $self.f32x4[$y], + $self.f32x4[$z], + $self.f32x4[$w], + ] }, + { v128: $crate::v32x4_shuffle::<$x, $y, $z, $w>($self.v128, $self.v128) }, + { f32x4: [ + $self.f32x4[$x], + $self.f32x4[$y], + $self.f32x4[$z], + $self.f32x4[$w], + ] }, + ) + }; + ($self:expr, $x:literal, $y:literal, $z:literal) => { + $crate::match_architecture!( + Simd32x3, + { v32x4: $crate::swizzle!($self.v32x4, $x, $y, $z, 0) }, + { f32x3: [ + $self.f32x3[$x], + $self.f32x3[$y], + $self.f32x3[$z], + ] }, + ) + }; + ($self:expr, $x:literal, $y:literal) => { + $crate::match_architecture!( + Simd32x2, + { v32x4: $crate::swizzle!($self.v32x4, $x, $y, 0, 0) }, + { f32x2: [ + $self.f32x2[$x], + $self.f32x2[$y], + ] }, + ) + }; +} + +impl Simd32x4 { + pub fn get_f(&self, index: usize) -> f32 { + unsafe { self.f32x4[index] } + } + + pub fn set_f(&mut self, index: usize, value: f32) { + unsafe { self.f32x4[index] = value; } + } +} + +impl Simd32x3 { + pub fn get_f(&self, index: usize) -> f32 { + unsafe { self.f32x3[index] } + } + + pub fn set_f(&mut self, index: usize, value: f32) { + unsafe { self.f32x3[index] = value; } + } +} + +impl Simd32x2 { + pub fn get_f(&self, index: usize) -> f32 { + unsafe { self.f32x2[index] } + } + + pub fn set_f(&mut self, index: usize, value: f32) { + unsafe { self.f32x2[index] = value; } + } +} + +impl std::convert::From<[f32; 4]> for Simd32x4 { + fn from(f32x4: [f32; 4]) -> Self { + Self { f32x4 } + } +} + +impl std::convert::From<[f32; 3]> for Simd32x3 { + fn from(f32x3: [f32; 3]) -> Self { + Self { f32x3 } + } +} + +impl std::convert::From<[f32; 2]> for Simd32x2 { + fn from(f32x2: [f32; 2]) -> Self { + Self { f32x2 } + } +} + +impl std::convert::From for Simd32x4 { + fn from(value: f32) -> Self { + Self { + f32x4: [value, value, value, value], + } + } +} + +impl std::convert::From for Simd32x3 { + fn from(value: f32) -> Self { + Self { + f32x3: [value, value, value], + } + } +} + +impl std::convert::From for Simd32x2 { + fn from(value: f32) -> Self { + Self { + f32x2: [value, value], + } + } +} + +impl std::fmt::Debug for Simd32x4 { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.debug_tuple("Vec4") + .field(&self.get_f(0)) + .field(&self.get_f(1)) + .field(&self.get_f(2)) + .field(&self.get_f(3)) + .finish() + } +} + +impl std::fmt::Debug for Simd32x3 { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.debug_tuple("Vec3") + .field(&self.get_f(0)) + .field(&self.get_f(1)) + .field(&self.get_f(2)) + .finish() + } +} + +impl std::fmt::Debug for Simd32x2 { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.debug_tuple("Vec2") + .field(&self.get_f(0)) + .field(&self.get_f(1)) + .finish() + } +} + +impl std::ops::Add for Simd32x4 { + type Output = Simd32x4; + + fn add(self, other: Self) -> Self { + match_architecture!( + Self, + { f128: _mm_add_ps(self.f128, other.f128) }, + { f128: vaddq_f32(self.f128, other.f128) }, + { v128: f32x4_add(self.v128, other.v128) }, + { f32x4: [ + self.f32x4[0] + other.f32x4[0], + self.f32x4[1] + other.f32x4[1], + self.f32x4[2] + other.f32x4[2], + self.f32x4[3] + other.f32x4[3], + ] }, + ) + } +} + +impl std::ops::Add for Simd32x3 { + type Output = Simd32x3; + + fn add(self, other: Self) -> Self { + match_architecture!( + Self, + { v32x4: self.v32x4 + other.v32x4 }, + { f32x3: [ + self.f32x3[0] + other.f32x3[0], + self.f32x3[1] + other.f32x3[1], + self.f32x3[2] + other.f32x3[2], + ] }, + ) + } +} + +impl std::ops::Add for Simd32x2 { + type Output = Simd32x2; + + fn add(self, other: Self) -> Self { + match_architecture!( + Self, + { v32x4: self.v32x4 + other.v32x4 }, + { f32x2: [ + self.f32x2[0] + other.f32x2[0], + self.f32x2[1] + other.f32x2[1], + ] }, + ) + } +} + +impl std::ops::Sub for Simd32x4 { + type Output = Simd32x4; + + fn sub(self, other: Self) -> Self { + match_architecture!( + Self, + { f128: _mm_sub_ps(self.f128, other.f128) }, + { f128: vsubq_f32(self.f128, other.f128) }, + { v128: f32x4_sub(self.v128, other.v128) }, + { f32x4: [ + self.f32x4[0] - other.f32x4[0], + self.f32x4[1] - other.f32x4[1], + self.f32x4[2] - other.f32x4[2], + self.f32x4[3] - other.f32x4[3], + ] }, + ) + } +} + +impl std::ops::Sub for Simd32x3 { + type Output = Simd32x3; + + fn sub(self, other: Self) -> Self { + match_architecture!( + Self, + { v32x4: self.v32x4 - other.v32x4 }, + { f32x3: [ + self.f32x3[0] - other.f32x3[0], + self.f32x3[1] - other.f32x3[1], + self.f32x3[2] - other.f32x3[2], + ] }, + ) + } +} + +impl std::ops::Sub for Simd32x2 { + type Output = Simd32x2; + + fn sub(self, other: Self) -> Self { + match_architecture!( + Self, + { v32x4: self.v32x4 - other.v32x4 }, + { f32x2: [ + self.f32x2[0] - other.f32x2[0], + self.f32x2[1] - other.f32x2[1], + ] }, + ) + } +} + +impl std::ops::Mul for Simd32x4 { + type Output = Simd32x4; + + fn mul(self, other: Self) -> Self { + match_architecture!( + Self, + { f128: _mm_mul_ps(self.f128, other.f128) }, + { f128: vmulq_f32(self.f128, other.f128) }, + { v128: f32x4_mul(self.v128, other.v128) }, + { f32x4: [ + self.f32x4[0] * other.f32x4[0], + self.f32x4[1] * other.f32x4[1], + self.f32x4[2] * other.f32x4[2], + self.f32x4[3] * other.f32x4[3], + ] }, + ) + } +} + +impl std::ops::Mul for Simd32x3 { + type Output = Simd32x3; + + fn mul(self, other: Self) -> Self { + match_architecture!( + Self, + { v32x4: self.v32x4 * other.v32x4 }, + { f32x3: [ + self.f32x3[0] * other.f32x3[0], + self.f32x3[1] * other.f32x3[1], + self.f32x3[2] * other.f32x3[2], + ] }, + ) + } +} + +impl std::ops::Mul for Simd32x2 { + type Output = Simd32x2; + + fn mul(self, other: Self) -> Self { + match_architecture!( + Self, + { v32x4: self.v32x4 * other.v32x4 }, + { f32x2: [ + self.f32x2[0] * other.f32x2[0], + self.f32x2[1] * other.f32x2[1], + ] }, + ) + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..15ffd16 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,210 @@ +#![cfg_attr(all(target_arch = "wasm32", target_feature = "simd128"), feature(wasm_simd))] +#![cfg_attr(all(any(target_arch = "arm", target_arch = "aarch64"), target_feature = "neon"), feature(stdsimd))] + +pub use simd::*; +pub mod complex; +pub mod ppga2d; +pub mod ppga3d; + +impl complex::Scalar { + pub const fn new(real: f32) -> Self { + Self { g0: real } + } + + pub fn real(self) -> f32 { + self.g0 + } + + pub fn sqrt(self) -> complex::MultiVector { + if self.g0 < 0.0 { + complex::MultiVector::new(0.0, (-self.g0).sqrt()) + } else { + complex::MultiVector::new(self.g0.sqrt(), 0.0) + } + } +} + +impl complex::MultiVector { + pub const fn new(real: f32, imaginary: f32) -> Self { + Self { + g0: simd::Simd32x2 { + f32x2: [real, imaginary], + }, + } + } + + pub fn real(self) -> f32 { + self.g0.get_f(0) + } + + pub fn imaginary(self) -> f32 { + self.g0.get_f(1) + } + + pub fn from_polar(magnitude: f32, angle: f32) -> Self { + Self::new(magnitude * angle.cos(), magnitude * angle.sin()) + } + + pub fn arg(self) -> f32 { + self.imaginary().atan2(self.real()) + } + + pub fn powf(self, exponent: f32) -> Self { + Self::from_polar(self.magnitude().g0.powf(exponent), self.arg() * exponent) + } +} + +impl ppga2d::Rotor { + pub fn from_angle(mut angle: f32) -> Self { + angle *= 0.5; + Self { + g0: simd::Simd32x2::from([angle.cos(), angle.sin()]), + } + } + + pub fn angle(self) -> f32 { + self.g0.get_f(1).atan2(self.g0.get_f(0)) * 2.0 + } +} + +impl ppga2d::Point { + pub fn from_coordinates(coordinates: [f32; 2]) -> Self { + Self { + g0: simd::Simd32x3::from([1.0, coordinates[0], coordinates[1]]), + } + } + + pub fn from_direction(coordinates: [f32; 2]) -> Self { + Self { + g0: simd::Simd32x3::from([0.0, coordinates[0], coordinates[1]]), + } + } +} + +impl ppga2d::Plane { + pub fn from_normal_and_distance(normal: [f32; 2], distance: f32) -> Self { + Self { + g0: simd::Simd32x3::from([distance, normal[1], -normal[0]]), + } + } +} + +impl ppga2d::Translator { + pub fn from_coordinates(coordinates: [f32; 2]) -> Self { + Self { + g0: simd::Simd32x3::from([1.0, coordinates[1] * 0.5, coordinates[0] * -0.5]), + } + } +} + +/// All elements set to `0.0` +pub trait Zero { + fn zero() -> Self; +} + +/// All elements set to `0.0`, except for the scalar, which is set to `1.0` +pub trait One { + fn one() -> Self; +} + +/// Element order reversed +pub trait Dual { + type Output; + fn dual(self) -> Self::Output; +} + +/// Also called reversion +pub trait Transpose { + type Output; + fn transpose(self) -> Self::Output; +} + +/// Also called involution +pub trait Automorph { + type Output; + fn automorph(self) -> Self::Output; +} + +pub trait Conjugate { + type Output; + fn conjugate(self) -> Self::Output; +} + +pub trait GeometricProduct { + type Output; + fn geometric_product(self, other: T) -> Self::Output; +} + +/// Also called join +pub trait RegressiveProduct { + type Output; + fn regressive_product(self, other: T) -> Self::Output; +} + +/// Also called meet or exterior product +pub trait OuterProduct { + type Output; + fn outer_product(self, other: T) -> Self::Output; +} + +/// Also called fat dot product +pub trait InnerProduct { + type Output; + fn inner_product(self, other: T) -> Self::Output; +} + +pub trait LeftContraction { + type Output; + fn left_contraction(self, other: T) -> Self::Output; +} + +pub trait RightContraction { + type Output; + fn right_contraction(self, other: T) -> Self::Output; +} + +pub trait ScalarProduct { + type Output; + fn scalar_product(self, other: T) -> Self::Output; +} + +pub trait Reflection { + type Output; + fn reflection(self, other: T) -> Self::Output; +} + +/// Also called sandwich product +pub trait Transformation { + type Output; + fn transformation(self, other: T) -> Self::Output; +} + +/// Square of the magnitude +pub trait SquaredMagnitude { + type Output; + fn squared_magnitude(self) -> Self::Output; +} + +/// Also called amplitude, absolute value or norm +pub trait Magnitude { + type Output; + fn magnitude(self) -> Self::Output; +} + +/// Also called normalize +pub trait Signum { + type Output; + fn signum(self) -> Self::Output; +} + +/// Exponentiation by scalar negative one +pub trait Inverse { + type Output; + fn inverse(self) -> Self::Output; +} + +/// Exponentiation by a scalar integer +pub trait Powi { + type Output; + fn powi(self, exponent: isize) -> Self::Output; +}