From 365ddcba442c8c7b1590d1a1081156091d9df160 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 17 Apr 2021 23:32:37 +0200 Subject: [PATCH] Adds support for negative basis elements and fixes wrong duals. --- .github/workflows/actions.yaml | 4 +- Cargo.toml | 2 +- codegen/Cargo.toml | 2 +- codegen/src/algebra.rs | 169 +++++++++++++++------------------ codegen/src/compile.rs | 163 +++++++++++++------------------ codegen/src/main.rs | 30 +++--- 6 files changed, 159 insertions(+), 211 deletions(-) diff --git a/.github/workflows/actions.yaml b/.github/workflows/actions.yaml index 52050d8..59607fe 100644 --- a/.github/workflows/actions.yaml +++ b/.github/workflows/actions.yaml @@ -24,9 +24,9 @@ jobs: - name: complex descriptor: "complex:-1;Scalar:1;MultiVector:1,e0" - name: ppga2d - descriptor: "ppga2d:0,1,1;Scalar:1;MultiVector:1,e12,e1,e2|e0,e012,e01,e02;Rotor:1,e12;Point:e12,e01,e02;Plane:e0,e2,e1;Translator:1,e01,e02;Motor:1,e12,e01,e02;MotorDual:e012,e0,e2,e1" + descriptor: "ppga2d:0,1,1;Scalar:1;MultiVector:1,e12,e1,e2|e0,e012,e01,-e02;Rotor:1,e12;Point:e12,e01,-e02;Plane:e0,e2,e1;Translator:1,e01,-e02;Motor:1,e12,e01,-e02;MotorDual:e012,e0,e2,e1" - name: ppga3d - descriptor: "ppga3d:0,1,1,1;Scalar:1;MultiVector:1,e23,e13,e12|e0,e023,e013,e012|e123,e1,e2,e3|e0123,e01,e02,e03;Rotor:1,e23,e13,e12;Point:e123,e023,e013,e012;Plane:e0,e1,e2,e3;Line:e01,e02,e03|e23,e13,e12;Translator:1,e01,e02,e03;Motor:1,e23,e13,e12|e0123,e01,e02,e03;PointAndPlane:e123,e023,e013,e012|e0,e1,e2,e3" + descriptor: "ppga3d:0,1,1,1;Scalar:1;MultiVector:1,e23,-e13,e12|e0,-e023,e013,-e012|e123,e1,e2,e3|e0123,e01,e02,e03;Rotor:1,e23,-e13,e12;Point:e123,-e023,e013,-e012;Plane:e0,e1,e2,e3;Line:e01,e02,e03|e23,-e13,e12;Translator:1,e01,e02,e03;Motor:1,e23,-e13,e12|e0123,e01,e02,e03;PointAndPlane:e123,-e023,e013,-e012|e0,e1,e2,e3" steps: - uses: actions/download-artifact@v2 with: diff --git a/Cargo.toml b/Cargo.toml index 8f3a65c..ea4f426 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "geometric_algebra" -version = "0.1.0" +version = "0.1.1" authors = ["Alexander Meißner "] description = "Generate(d) custom libraries for geometric algebras" repository = "https://github.com/Lichtso/geometric_algebra/" diff --git a/codegen/Cargo.toml b/codegen/Cargo.toml index cd6471f..0cf4251 100644 --- a/codegen/Cargo.toml +++ b/codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "codegen" -version = "0.1.0" +version = "0.1.1" authors = ["Alexander Meißner "] edition = "2018" publish = false \ No newline at end of file diff --git a/codegen/src/algebra.rs b/codegen/src/algebra.rs index 9c4a76c..79f6b5c 100644 --- a/codegen/src/algebra.rs +++ b/codegen/src/algebra.rs @@ -8,7 +8,14 @@ impl<'a> GeometricAlgebra<'a> { } pub fn basis(&self) -> impl Iterator + '_ { - (0..self.basis_size() as BasisElementIndex).map(|index| BasisElement { index }) + (0..self.basis_size() as BasisElementIndex).map(move |index| { + let mut element = BasisElement::from_index(index); + let dual = element.dual(self); + if dual.cmp(&element) == std::cmp::Ordering::Less { + element.scalar = element.dual(self).scalar; + } + element + }) } pub fn sorted_basis(&self) -> Vec { @@ -18,24 +25,35 @@ impl<'a> GeometricAlgebra<'a> { } } -type BasisElementIndex = u16; +pub type BasisElementIndex = u16; #[derive(Clone, PartialEq, Eq, Hash)] pub struct BasisElement { + pub scalar: isize, pub index: BasisElementIndex, } impl BasisElement { - pub fn new(name: &str) -> Self { - Self { - index: if name == "1" { - 0 - } else { - let mut generator_indices = name.chars(); - assert_eq!(generator_indices.next().unwrap(), 'e'); - generator_indices.fold(0, |index, generator_index| index | (1 << (generator_index.to_digit(16).unwrap()))) - }, + pub fn from_index(index: BasisElementIndex) -> Self { + Self { scalar: 1, index } + } + + pub fn parse(mut name: &str, algebra: &GeometricAlgebra) -> Self { + let mut result = Self::from_index(0); + if name.starts_with('-') { + name = &name[1..]; + result.scalar = -1; } + if name == "1" { + return result; + } + let mut generator_indices = name.chars(); + assert_eq!(generator_indices.next().unwrap(), 'e'); + for generator_index in generator_indices { + let generator = Self::from_index(1 << (generator_index.to_digit(16).unwrap())); + result = BasisElement::product(&result, &generator, algebra); + } + result } pub fn grade(&self) -> usize { @@ -47,8 +65,30 @@ impl BasisElement { } pub fn dual(&self, algebra: &GeometricAlgebra) -> Self { - Self { + let mut result = Self { + scalar: self.scalar, index: algebra.basis_size() as BasisElementIndex - 1 - self.index, + }; + result.scalar *= BasisElement::product(&self, &result, &algebra).scalar; + result + } + + pub fn product(a: &Self, b: &Self, algebra: &GeometricAlgebra) -> Self { + let commutations = a.component_bits().fold((0, a.index, b.index), |(commutations, a, b), index| { + let hurdles_a = a & (BasisElementIndex::MAX << (index + 1)); + let hurdles_b = b & ((1 << index) - 1); + ( + commutations + Self::from_index(hurdles_a | hurdles_b).grade(), + a & !(1 << index), + b ^ (1 << index), + ) + }); + Self { + scalar: Self::from_index(a.index & b.index) + .component_bits() + .map(|i| algebra.generator_squares[i]) + .fold(a.scalar * b.scalar * if commutations.0 % 2 == 0 { 1 } else { -1 }, |a, b| a * b), + index: a.index ^ b.index, } } } @@ -56,10 +96,10 @@ impl BasisElement { impl std::fmt::Display for BasisElement { fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { if self.index == 0 { - formatter.pad("1") + formatter.pad_integral(self.scalar >= 0, "", "1") } else { let string = format!("e{}", self.component_bits().map(|index| format!("{:X}", index)).collect::()); - formatter.pad(string.as_str()) + formatter.pad_integral(self.scalar >= 0, "", string.as_str()) } } } @@ -86,67 +126,21 @@ impl std::cmp::PartialOrd for BasisElement { } } -#[derive(Clone, PartialEq)] -pub struct ScaledElement { - pub scalar: isize, - pub unit: BasisElement, -} - -impl ScaledElement { - pub fn from(element: &BasisElement) -> Self { - Self { - scalar: 1, - unit: element.clone(), - } - } - - pub fn product(a: &Self, b: &Self, algebra: &GeometricAlgebra) -> Self { - let commutations = a - .unit - .component_bits() - .fold((0, a.unit.index, b.unit.index), |(commutations, a, b), index| { - let hurdles_a = a & (BasisElementIndex::MAX << (index + 1)); - let hurdles_b = b & ((1 << index) - 1); - ( - commutations - + BasisElement { - index: hurdles_a | hurdles_b, - } - .grade(), - a & !(1 << index), - b ^ (1 << index), - ) - }); - Self { - scalar: BasisElement { - index: a.unit.index & b.unit.index, - } - .component_bits() - .map(|i| algebra.generator_squares[i]) - .fold(a.scalar * b.scalar * if commutations.0 % 2 == 0 { 1 } else { -1 }, |a, b| a * b), - unit: BasisElement { - index: a.unit.index ^ b.unit.index, - }, - } - } -} - -impl std::fmt::Display for ScaledElement { - fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - let string = self.unit.to_string(); - formatter.pad_integral(self.scalar >= 0, "", if self.scalar == 0 { "0" } else { string.as_str() }) - } -} - #[derive(Clone)] pub struct Involution { - pub terms: Vec, + pub terms: Vec<(BasisElement, BasisElement)>, } impl Involution { pub fn identity(algebra: &GeometricAlgebra) -> Self { Self { - terms: algebra.basis().map(|element| ScaledElement { scalar: 1, unit: element }).collect(), + terms: algebra.basis().map(|element| (element.clone(), element)).collect(), + } + } + + pub fn projection(class: &MultiVectorClass) -> Self { + Self { + terms: class.flat_basis().iter().map(|element| (element.clone(), element.clone())).collect(), } } @@ -158,9 +152,10 @@ impl Involution { terms: self .terms .iter() - .map(|element| ScaledElement { - scalar: if grade_negation(element.unit.grade()) { -1 } else { 1 }, - unit: element.unit.clone(), + .map(|(key, value)| { + let mut element = value.clone(); + element.scalar *= if grade_negation(value.grade()) { -1 } else { 1 }; + (key.clone(), element) }) .collect(), } @@ -168,14 +163,7 @@ impl Involution { pub fn dual(&self, algebra: &GeometricAlgebra) -> Self { Self { - terms: self - .terms - .iter() - .map(|term| ScaledElement { - scalar: term.scalar, - unit: term.unit.dual(algebra), - }) - .collect(), + terms: self.terms.iter().map(|(key, value)| (key.clone(), value.dual(algebra))).collect(), } } @@ -193,7 +181,7 @@ impl Involution { #[derive(Clone, PartialEq)] pub struct ProductTerm { - pub product: ScaledElement, + pub product: BasisElement, pub factor_a: BasisElement, pub factor_b: BasisElement, } @@ -204,15 +192,15 @@ pub struct Product { } impl Product { - pub fn product(a: &[ScaledElement], b: &[ScaledElement], algebra: &GeometricAlgebra) -> Self { + pub fn product(a: &[BasisElement], b: &[BasisElement], algebra: &GeometricAlgebra) -> Self { Self { terms: a .iter() .map(|a| { b.iter().map(move |b| ProductTerm { - product: ScaledElement::product(&a, &b, algebra), - factor_a: a.unit.clone(), - factor_b: b.unit.clone(), + product: BasisElement::product(&a, &b, algebra), + factor_a: a.clone(), + factor_b: b.clone(), }) }) .flatten() @@ -229,7 +217,7 @@ impl Product { terms: self .terms .iter() - .filter(|term| grade_projection(term.factor_a.grade(), term.factor_b.grade(), term.product.unit.grade())) + .filter(|term| grade_projection(term.factor_a.grade(), term.factor_b.grade(), term.product.grade())) .cloned() .collect(), } @@ -241,10 +229,7 @@ impl Product { .terms .iter() .map(|term| ProductTerm { - product: ScaledElement { - scalar: term.product.scalar, - unit: term.product.unit.dual(algebra), - }, + product: term.product.dual(algebra), factor_a: term.factor_a.dual(algebra), factor_b: term.factor_b.dual(algebra), }) @@ -253,7 +238,7 @@ impl Product { } pub fn products(algebra: &GeometricAlgebra) -> Vec<(&'static str, Self)> { - let basis = algebra.basis().map(|element| ScaledElement::from(&element)).collect::>(); + let basis = algebra.basis().collect::>(); let product = Self::product(&basis, &basis, algebra); vec![ ("GeometricProduct", product.clone()), @@ -270,7 +255,7 @@ impl Product { #[derive(Default)] pub struct MultiVectorClassRegistry { pub classes: Vec, - index_by_signature: std::collections::HashMap, usize>, + index_by_signature: std::collections::HashMap, usize>, } impl MultiVectorClassRegistry { @@ -279,7 +264,7 @@ impl MultiVectorClassRegistry { self.classes.push(class); } - pub fn get(&self, signature: &[BasisElement]) -> Option<&MultiVectorClass> { + pub fn get(&self, signature: &[BasisElementIndex]) -> Option<&MultiVectorClass> { self.index_by_signature.get(signature).map(|index| &self.classes[*index]) } } diff --git a/codegen/src/compile.rs b/codegen/src/compile.rs index faac77d..10c560a 100644 --- a/codegen/src/compile.rs +++ b/codegen/src/compile.rs @@ -1,5 +1,5 @@ use crate::{ - algebra::{BasisElement, Involution, MultiVectorClass, MultiVectorClassRegistry, Product}, + algebra::{BasisElement, BasisElementIndex, Involution, MultiVectorClass, MultiVectorClassRegistry, Product}, ast::{AstNode, DataType, Expression, ExpressionContent, Parameter}, }; @@ -146,9 +146,9 @@ impl MultiVectorClass { self.grouped_basis.iter().flatten().cloned().collect() } - pub fn signature(&self) -> Vec { - let mut signature = self.flat_basis(); - signature.sort(); + pub fn signature(&self) -> Vec { + let mut signature: Vec = self.grouped_basis.iter().flatten().map(|element| element.index).collect(); + signature.sort_unstable(); signature } @@ -204,29 +204,48 @@ impl MultiVectorClass { involution: &Involution, parameter_a: &Parameter<'a>, registry: &'a MultiVectorClassRegistry, + project: bool, ) -> AstNode<'a> { let a_flat_basis = parameter_a.multi_vector_class().flat_basis(); - let result_signature = a_flat_basis - .iter() - .map(|element| involution.terms[element.index as usize].unit.clone()) - .collect::>(); - let mut result_signature = result_signature.into_iter().collect::>(); - result_signature.sort(); + let mut result_signature = Vec::new(); + for a_element in a_flat_basis.iter() { + for (in_element, out_element) in involution.terms.iter() { + if in_element.index == a_element.index { + result_signature.push(out_element.index); + break; + } + } + } + if project { + for (in_element, _out_element) in involution.terms.iter() { + if !a_flat_basis.iter().any(|element| element.index == in_element.index) { + return AstNode::None; + } + } + } + result_signature.sort_unstable(); if let Some(result_class) = registry.get(&result_signature) { let result_flat_basis = result_class.flat_basis(); let mut body = Vec::new(); let mut base_index = 0; for result_group in result_class.grouped_basis.iter() { let size = result_group.len(); - let a_indices = (0..size) + let (factors, a_indices): (Vec<_>, Vec<_>) = (0..size) .map(|index_in_group| { let result_element = &result_flat_basis[base_index + index_in_group]; - let a_element = &involution.terms[result_element.index as usize].unit; - parameter_a - .multi_vector_class() - .index_in_group(a_flat_basis.iter().position(|element| element == a_element).unwrap()) + let involution_element = involution + .terms + .iter() + .position(|(_in_element, out_element)| out_element.index == result_element.index) + .unwrap(); + let (in_element, out_element) = &involution.terms[involution_element]; + let index_in_a = a_flat_basis.iter().position(|a_element| a_element.index == in_element.index).unwrap(); + ( + out_element.scalar * result_element.scalar * in_element.scalar * a_flat_basis[index_in_a].scalar, + parameter_a.multi_vector_class().index_in_group(index_in_a), + ) }) - .collect::>(); + .unzip(); let a_group_index = a_indices[0].0; let expression = Expression { size, @@ -243,13 +262,7 @@ impl MultiVectorClass { }), Box::new(Expression { size, - content: ExpressionContent::Constant( - DataType::SimdVector(size), - result_group - .iter() - .map(|element| involution.terms[element.index as usize].scalar) - .collect(), - ), + content: ExpressionContent::Constant(DataType::SimdVector(size), factors), }), ), }; @@ -274,58 +287,6 @@ impl MultiVectorClass { } } - pub fn conversion<'a>(name: &'static str, parameter_a: &Parameter<'a>, result_class: &'a MultiVectorClass) -> AstNode<'a> { - if parameter_a.multi_vector_class() == result_class { - return AstNode::None; - } - let a_flat_basis = parameter_a.multi_vector_class().flat_basis(); - let result_flat_basis = result_class.flat_basis(); - for element in &result_flat_basis { - if !a_flat_basis.contains(element) { - return AstNode::None; - } - } - let mut body = Vec::new(); - let mut base_index = 0; - for result_group in result_class.grouped_basis.iter() { - let size = result_group.len(); - let a_indices = (0..size) - .map(|index_in_group| { - let result_element = &result_flat_basis[base_index + index_in_group]; - parameter_a - .multi_vector_class() - .index_in_group(a_flat_basis.iter().position(|a_element| a_element == result_element).unwrap()) - }) - .collect::>(); - let a_group_index = a_indices[0].0; - let expression = Expression { - size, - content: ExpressionContent::Gather( - Box::new(Expression { - size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(), - content: ExpressionContent::Variable(parameter_a.name), - }), - a_indices, - ), - }; - body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression)))); - base_index += size; - } - AstNode::TraitImplementation { - result: Parameter { - name, - data_type: DataType::MultiVector(result_class), - }, - parameters: vec![parameter_a.clone()], - body: vec![AstNode::ReturnStatement { - expression: Box::new(Expression { - size: 1, - content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body), - }), - }], - } - } - pub fn sum<'a>( name: &'static str, parameter_a: &Parameter<'a>, @@ -339,8 +300,8 @@ impl MultiVectorClass { .chain(b_flat_basis.iter()) .cloned() .collect::>(); - let mut result_signature = result_signature.into_iter().collect::>(); - result_signature.sort(); + let mut result_signature = result_signature.into_iter().map(|element| element.index).collect::>(); + result_signature.sort_unstable(); if let Some(result_class) = registry.get(&result_signature) { let parameters = [(parameter_a, &a_flat_basis), (parameter_b, &b_flat_basis)]; let mut body = Vec::new(); @@ -351,10 +312,10 @@ impl MultiVectorClass { let terms: Vec<_> = result_group .iter() .map(|result_element| { - if let Some(index_in_group) = flat_basis.iter().position(|element| element == result_element) { - let index_pair = parameter.multi_vector_class().index_in_group(index_in_group); + if let Some(index_in_flat_basis) = flat_basis.iter().position(|element| element.index == result_element.index) { + let index_pair = parameter.multi_vector_class().index_in_group(index_in_flat_basis); parameter_group_index = Some(index_pair.0); - (1, index_pair) + (result_element.scalar * flat_basis[index_in_flat_basis].scalar, index_pair) } else { (0, (0, 0)) } @@ -428,20 +389,30 @@ impl MultiVectorClass { let b_flat_basis = parameter_b.multi_vector_class().flat_basis(); let mut result_signature = std::collections::HashSet::new(); for product_term in product.terms.iter() { - if a_flat_basis.contains(&product_term.factor_a) && b_flat_basis.contains(&product_term.factor_b) { - result_signature.insert(product_term.product.unit.clone()); + if a_flat_basis.iter().any(|e| e.index == product_term.factor_a.index) + && b_flat_basis.iter().any(|e| e.index == product_term.factor_b.index) + { + result_signature.insert(product_term.product.index); } } let mut result_signature = result_signature.into_iter().collect::>(); - result_signature.sort(); + result_signature.sort_unstable(); if let Some(result_class) = registry.get(&result_signature) { let result_flat_basis = result_class.flat_basis(); let mut sorted_terms = vec![vec![(0, 0); a_flat_basis.len()]; result_flat_basis.len()]; for product_term in product.terms.iter() { - if let Some(y) = result_flat_basis.iter().position(|e| e == &product_term.product.unit) { - if let Some(x) = a_flat_basis.iter().position(|e| e == &product_term.factor_a) { - if let Some(gather_index) = b_flat_basis.iter().position(|e| e == &product_term.factor_b) { - sorted_terms[y][x] = (product_term.product.scalar, gather_index); + if let Some(y) = result_flat_basis.iter().position(|e| e.index == product_term.product.index) { + if let Some(x) = a_flat_basis.iter().position(|e| e.index == product_term.factor_a.index) { + if let Some(gather_index) = b_flat_basis.iter().position(|e| e.index == product_term.factor_b.index) { + sorted_terms[y][x] = ( + result_flat_basis[y].scalar + * product_term.product.scalar + * a_flat_basis[x].scalar + * product_term.factor_a.scalar + * b_flat_basis[gather_index].scalar + * product_term.factor_b.scalar, + gather_index, + ); } } } @@ -469,7 +440,7 @@ impl MultiVectorClass { }, vec![(0, 0); expression.size], vec![(0, 0); expression.size], - vec![false; expression.size], + vec![0; expression.size], ); for (index_in_a, a_terms) in transposed_terms.enumerate() { if a_terms.iter().all(|(factor, _)| *factor == 0) { @@ -488,10 +459,7 @@ impl MultiVectorClass { .enumerate() .map(|(index, (factor, _index_pair))| b_indices[if *factor == 0 { non_zero_index } else { index }]) .collect::>(); - let is_contractable = a_terms - .iter() - .enumerate() - .all(|(i, (factor, _))| *factor == 0 || *factor == 1 && !contraction.4[i]) + let is_contractable = a_terms.iter().enumerate().all(|(i, (factor, _))| *factor == 0 || contraction.4[i] == 0) && (contraction.0.content == ExpressionContent::None || contraction.0.size == parameter_a.multi_vector_class().grouped_basis[a_group_index].len()) && (contraction.1.content == ExpressionContent::None @@ -511,10 +479,10 @@ impl MultiVectorClass { contraction.3 = b_indices.iter().map(|(b_group_index, _)| (*b_group_index, 0)).collect(); } for (i, (factor, _index_in_b)) in a_terms.iter().enumerate() { - if *factor == 1 { + if *factor != 0 { contraction.2[i] = a_indices[i]; contraction.3[i] = b_indices[i]; - contraction.4[i] = true; + contraction.4[i] = *factor; } } } else { @@ -563,7 +531,7 @@ impl MultiVectorClass { }; } } - if contraction.4.iter().any(|mask| *mask) { + if contraction.4.iter().any(|scalar| *scalar != 0) { expression = Expression { size, content: ExpressionContent::Add( @@ -586,10 +554,7 @@ impl MultiVectorClass { }), Box::new(Expression { size, - content: ExpressionContent::Constant( - DataType::SimdVector(size), - contraction.4.iter().map(|value| *value as isize).collect(), - ), + content: ExpressionContent::Constant(DataType::SimdVector(size), contraction.4), }), ), }), diff --git a/codegen/src/main.rs b/codegen/src/main.rs index df93457..3b133bb 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -6,7 +6,7 @@ mod glsl; mod rust; use crate::{ - algebra::{BasisElement, GeometricAlgebra, Involution, MultiVectorClass, MultiVectorClassRegistry, Product, ScaledElement}, + algebra::{BasisElement, GeometricAlgebra, Involution, MultiVectorClass, MultiVectorClassRegistry, Product}, ast::{AstNode, DataType, Parameter}, emit::Emitter, }; @@ -31,13 +31,9 @@ fn main() { let involutions = Involution::involutions(&algebra); let products = Product::products(&algebra); let basis = algebra.sorted_basis(); - for a in basis.iter() { - for b in basis.iter() { - print!( - "{:1$} ", - ScaledElement::product(&ScaledElement::from(&a), &ScaledElement::from(&b), &algebra), - generator_squares.len() + 2 - ); + for b in basis.iter() { + for a in basis.iter() { + print!("{:1$} ", BasisElement::product(&a, &b, &algebra), generator_squares.len() + 2); } println!(); } @@ -53,7 +49,7 @@ fn main() { .map(|group_descriptor| { group_descriptor .split(',') - .map(|element_name| BasisElement::new(element_name)) + .map(|element_name| BasisElement::parse(element_name, &algebra)) .collect::>() }) .collect::>(), @@ -80,7 +76,7 @@ fn main() { } } for (name, involution) in involutions.iter() { - let ast_node = MultiVectorClass::involution(name, &involution, ¶meter_a, ®istry); + let ast_node = MultiVectorClass::involution(name, &involution, ¶meter_a, ®istry, false); emitter.emit(&ast_node).unwrap(); if ast_node != AstNode::None { single_trait_implementations.insert(name.to_string(), ast_node); @@ -93,11 +89,13 @@ fn main() { name: "other", data_type: DataType::MultiVector(class_b), }; - let name = "Into"; - let ast_node = MultiVectorClass::conversion(name, ¶meter_a, ¶meter_b.multi_vector_class()); - emitter.emit(&ast_node).unwrap(); - if ast_node != AstNode::None { - trait_implementations.insert(name.to_string(), ast_node); + if class_a != class_b { + let name = "Into"; + let ast_node = MultiVectorClass::involution(name, &Involution::projection(&class_b), ¶meter_a, ®istry, true); + emitter.emit(&ast_node).unwrap(); + if ast_node != AstNode::None { + trait_implementations.insert(name.to_string(), ast_node); + } } for name in &["Add", "Sub"] { let ast_node = MultiVectorClass::sum(*name, ¶meter_a, ¶meter_b, ®istry); @@ -135,7 +133,7 @@ fn main() { } for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() { if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") { - if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement { index: 0 }]] { + if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement::from_index(0)]] { if let Some(magnitude) = single_trait_implementations.get("Magnitude") { let signum = MultiVectorClass::derive_signum("Signum", &geometric_product, &magnitude, ¶meter_a); emitter.emit(&signum).unwrap();