Adds support for negative basis elements and fixes wrong duals.
This commit is contained in:
parent
62f230ff9a
commit
365ddcba44
6 changed files with 159 additions and 211 deletions
4
.github/workflows/actions.yaml
vendored
4
.github/workflows/actions.yaml
vendored
|
|
@ -24,9 +24,9 @@ jobs:
|
|||
- name: complex
|
||||
descriptor: "complex:-1;Scalar:1;MultiVector:1,e0"
|
||||
- name: ppga2d
|
||||
descriptor: "ppga2d:0,1,1;Scalar:1;MultiVector:1,e12,e1,e2|e0,e012,e01,e02;Rotor:1,e12;Point:e12,e01,e02;Plane:e0,e2,e1;Translator:1,e01,e02;Motor:1,e12,e01,e02;MotorDual:e012,e0,e2,e1"
|
||||
descriptor: "ppga2d:0,1,1;Scalar:1;MultiVector:1,e12,e1,e2|e0,e012,e01,-e02;Rotor:1,e12;Point:e12,e01,-e02;Plane:e0,e2,e1;Translator:1,e01,-e02;Motor:1,e12,e01,-e02;MotorDual:e012,e0,e2,e1"
|
||||
- name: ppga3d
|
||||
descriptor: "ppga3d:0,1,1,1;Scalar:1;MultiVector:1,e23,e13,e12|e0,e023,e013,e012|e123,e1,e2,e3|e0123,e01,e02,e03;Rotor:1,e23,e13,e12;Point:e123,e023,e013,e012;Plane:e0,e1,e2,e3;Line:e01,e02,e03|e23,e13,e12;Translator:1,e01,e02,e03;Motor:1,e23,e13,e12|e0123,e01,e02,e03;PointAndPlane:e123,e023,e013,e012|e0,e1,e2,e3"
|
||||
descriptor: "ppga3d:0,1,1,1;Scalar:1;MultiVector:1,e23,-e13,e12|e0,-e023,e013,-e012|e123,e1,e2,e3|e0123,e01,e02,e03;Rotor:1,e23,-e13,e12;Point:e123,-e023,e013,-e012;Plane:e0,e1,e2,e3;Line:e01,e02,e03|e23,-e13,e12;Translator:1,e01,e02,e03;Motor:1,e23,-e13,e12|e0123,e01,e02,e03;PointAndPlane:e123,-e023,e013,-e012|e0,e1,e2,e3"
|
||||
steps:
|
||||
- uses: actions/download-artifact@v2
|
||||
with:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "geometric_algebra"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
authors = ["Alexander Meißner <AlexanderMeissner@gmx.net>"]
|
||||
description = "Generate(d) custom libraries for geometric algebras"
|
||||
repository = "https://github.com/Lichtso/geometric_algebra/"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "codegen"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
authors = ["Alexander Meißner <AlexanderMeissner@gmx.net>"]
|
||||
edition = "2018"
|
||||
publish = false
|
||||
|
|
@ -8,7 +8,14 @@ impl<'a> GeometricAlgebra<'a> {
|
|||
}
|
||||
|
||||
pub fn basis(&self) -> impl Iterator<Item = BasisElement> + '_ {
|
||||
(0..self.basis_size() as BasisElementIndex).map(|index| BasisElement { index })
|
||||
(0..self.basis_size() as BasisElementIndex).map(move |index| {
|
||||
let mut element = BasisElement::from_index(index);
|
||||
let dual = element.dual(self);
|
||||
if dual.cmp(&element) == std::cmp::Ordering::Less {
|
||||
element.scalar = element.dual(self).scalar;
|
||||
}
|
||||
element
|
||||
})
|
||||
}
|
||||
|
||||
pub fn sorted_basis(&self) -> Vec<BasisElement> {
|
||||
|
|
@ -18,24 +25,35 @@ impl<'a> GeometricAlgebra<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
type BasisElementIndex = u16;
|
||||
pub type BasisElementIndex = u16;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash)]
|
||||
pub struct BasisElement {
|
||||
pub scalar: isize,
|
||||
pub index: BasisElementIndex,
|
||||
}
|
||||
|
||||
impl BasisElement {
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
index: if name == "1" {
|
||||
0
|
||||
} else {
|
||||
pub fn from_index(index: BasisElementIndex) -> Self {
|
||||
Self { scalar: 1, index }
|
||||
}
|
||||
|
||||
pub fn parse(mut name: &str, algebra: &GeometricAlgebra) -> Self {
|
||||
let mut result = Self::from_index(0);
|
||||
if name.starts_with('-') {
|
||||
name = &name[1..];
|
||||
result.scalar = -1;
|
||||
}
|
||||
if name == "1" {
|
||||
return result;
|
||||
}
|
||||
let mut generator_indices = name.chars();
|
||||
assert_eq!(generator_indices.next().unwrap(), 'e');
|
||||
generator_indices.fold(0, |index, generator_index| index | (1 << (generator_index.to_digit(16).unwrap())))
|
||||
},
|
||||
for generator_index in generator_indices {
|
||||
let generator = Self::from_index(1 << (generator_index.to_digit(16).unwrap()));
|
||||
result = BasisElement::product(&result, &generator, algebra);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn grade(&self) -> usize {
|
||||
|
|
@ -47,8 +65,30 @@ impl BasisElement {
|
|||
}
|
||||
|
||||
pub fn dual(&self, algebra: &GeometricAlgebra) -> Self {
|
||||
Self {
|
||||
let mut result = Self {
|
||||
scalar: self.scalar,
|
||||
index: algebra.basis_size() as BasisElementIndex - 1 - self.index,
|
||||
};
|
||||
result.scalar *= BasisElement::product(&self, &result, &algebra).scalar;
|
||||
result
|
||||
}
|
||||
|
||||
pub fn product(a: &Self, b: &Self, algebra: &GeometricAlgebra) -> Self {
|
||||
let commutations = a.component_bits().fold((0, a.index, b.index), |(commutations, a, b), index| {
|
||||
let hurdles_a = a & (BasisElementIndex::MAX << (index + 1));
|
||||
let hurdles_b = b & ((1 << index) - 1);
|
||||
(
|
||||
commutations + Self::from_index(hurdles_a | hurdles_b).grade(),
|
||||
a & !(1 << index),
|
||||
b ^ (1 << index),
|
||||
)
|
||||
});
|
||||
Self {
|
||||
scalar: Self::from_index(a.index & b.index)
|
||||
.component_bits()
|
||||
.map(|i| algebra.generator_squares[i])
|
||||
.fold(a.scalar * b.scalar * if commutations.0 % 2 == 0 { 1 } else { -1 }, |a, b| a * b),
|
||||
index: a.index ^ b.index,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -56,10 +96,10 @@ impl BasisElement {
|
|||
impl std::fmt::Display for BasisElement {
|
||||
fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
if self.index == 0 {
|
||||
formatter.pad("1")
|
||||
formatter.pad_integral(self.scalar >= 0, "", "1")
|
||||
} else {
|
||||
let string = format!("e{}", self.component_bits().map(|index| format!("{:X}", index)).collect::<String>());
|
||||
formatter.pad(string.as_str())
|
||||
formatter.pad_integral(self.scalar >= 0, "", string.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -86,67 +126,21 @@ impl std::cmp::PartialOrd for BasisElement {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct ScaledElement {
|
||||
pub scalar: isize,
|
||||
pub unit: BasisElement,
|
||||
}
|
||||
|
||||
impl ScaledElement {
|
||||
pub fn from(element: &BasisElement) -> Self {
|
||||
Self {
|
||||
scalar: 1,
|
||||
unit: element.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn product(a: &Self, b: &Self, algebra: &GeometricAlgebra) -> Self {
|
||||
let commutations = a
|
||||
.unit
|
||||
.component_bits()
|
||||
.fold((0, a.unit.index, b.unit.index), |(commutations, a, b), index| {
|
||||
let hurdles_a = a & (BasisElementIndex::MAX << (index + 1));
|
||||
let hurdles_b = b & ((1 << index) - 1);
|
||||
(
|
||||
commutations
|
||||
+ BasisElement {
|
||||
index: hurdles_a | hurdles_b,
|
||||
}
|
||||
.grade(),
|
||||
a & !(1 << index),
|
||||
b ^ (1 << index),
|
||||
)
|
||||
});
|
||||
Self {
|
||||
scalar: BasisElement {
|
||||
index: a.unit.index & b.unit.index,
|
||||
}
|
||||
.component_bits()
|
||||
.map(|i| algebra.generator_squares[i])
|
||||
.fold(a.scalar * b.scalar * if commutations.0 % 2 == 0 { 1 } else { -1 }, |a, b| a * b),
|
||||
unit: BasisElement {
|
||||
index: a.unit.index ^ b.unit.index,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ScaledElement {
|
||||
fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
let string = self.unit.to_string();
|
||||
formatter.pad_integral(self.scalar >= 0, "", if self.scalar == 0 { "0" } else { string.as_str() })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Involution {
|
||||
pub terms: Vec<ScaledElement>,
|
||||
pub terms: Vec<(BasisElement, BasisElement)>,
|
||||
}
|
||||
|
||||
impl Involution {
|
||||
pub fn identity(algebra: &GeometricAlgebra) -> Self {
|
||||
Self {
|
||||
terms: algebra.basis().map(|element| ScaledElement { scalar: 1, unit: element }).collect(),
|
||||
terms: algebra.basis().map(|element| (element.clone(), element)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn projection(class: &MultiVectorClass) -> Self {
|
||||
Self {
|
||||
terms: class.flat_basis().iter().map(|element| (element.clone(), element.clone())).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -158,9 +152,10 @@ impl Involution {
|
|||
terms: self
|
||||
.terms
|
||||
.iter()
|
||||
.map(|element| ScaledElement {
|
||||
scalar: if grade_negation(element.unit.grade()) { -1 } else { 1 },
|
||||
unit: element.unit.clone(),
|
||||
.map(|(key, value)| {
|
||||
let mut element = value.clone();
|
||||
element.scalar *= if grade_negation(value.grade()) { -1 } else { 1 };
|
||||
(key.clone(), element)
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
|
|
@ -168,14 +163,7 @@ impl Involution {
|
|||
|
||||
pub fn dual(&self, algebra: &GeometricAlgebra) -> Self {
|
||||
Self {
|
||||
terms: self
|
||||
.terms
|
||||
.iter()
|
||||
.map(|term| ScaledElement {
|
||||
scalar: term.scalar,
|
||||
unit: term.unit.dual(algebra),
|
||||
})
|
||||
.collect(),
|
||||
terms: self.terms.iter().map(|(key, value)| (key.clone(), value.dual(algebra))).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -193,7 +181,7 @@ impl Involution {
|
|||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct ProductTerm {
|
||||
pub product: ScaledElement,
|
||||
pub product: BasisElement,
|
||||
pub factor_a: BasisElement,
|
||||
pub factor_b: BasisElement,
|
||||
}
|
||||
|
|
@ -204,15 +192,15 @@ pub struct Product {
|
|||
}
|
||||
|
||||
impl Product {
|
||||
pub fn product(a: &[ScaledElement], b: &[ScaledElement], algebra: &GeometricAlgebra) -> Self {
|
||||
pub fn product(a: &[BasisElement], b: &[BasisElement], algebra: &GeometricAlgebra) -> Self {
|
||||
Self {
|
||||
terms: a
|
||||
.iter()
|
||||
.map(|a| {
|
||||
b.iter().map(move |b| ProductTerm {
|
||||
product: ScaledElement::product(&a, &b, algebra),
|
||||
factor_a: a.unit.clone(),
|
||||
factor_b: b.unit.clone(),
|
||||
product: BasisElement::product(&a, &b, algebra),
|
||||
factor_a: a.clone(),
|
||||
factor_b: b.clone(),
|
||||
})
|
||||
})
|
||||
.flatten()
|
||||
|
|
@ -229,7 +217,7 @@ impl Product {
|
|||
terms: self
|
||||
.terms
|
||||
.iter()
|
||||
.filter(|term| grade_projection(term.factor_a.grade(), term.factor_b.grade(), term.product.unit.grade()))
|
||||
.filter(|term| grade_projection(term.factor_a.grade(), term.factor_b.grade(), term.product.grade()))
|
||||
.cloned()
|
||||
.collect(),
|
||||
}
|
||||
|
|
@ -241,10 +229,7 @@ impl Product {
|
|||
.terms
|
||||
.iter()
|
||||
.map(|term| ProductTerm {
|
||||
product: ScaledElement {
|
||||
scalar: term.product.scalar,
|
||||
unit: term.product.unit.dual(algebra),
|
||||
},
|
||||
product: term.product.dual(algebra),
|
||||
factor_a: term.factor_a.dual(algebra),
|
||||
factor_b: term.factor_b.dual(algebra),
|
||||
})
|
||||
|
|
@ -253,7 +238,7 @@ impl Product {
|
|||
}
|
||||
|
||||
pub fn products(algebra: &GeometricAlgebra) -> Vec<(&'static str, Self)> {
|
||||
let basis = algebra.basis().map(|element| ScaledElement::from(&element)).collect::<Vec<_>>();
|
||||
let basis = algebra.basis().collect::<Vec<_>>();
|
||||
let product = Self::product(&basis, &basis, algebra);
|
||||
vec![
|
||||
("GeometricProduct", product.clone()),
|
||||
|
|
@ -270,7 +255,7 @@ impl Product {
|
|||
#[derive(Default)]
|
||||
pub struct MultiVectorClassRegistry {
|
||||
pub classes: Vec<MultiVectorClass>,
|
||||
index_by_signature: std::collections::HashMap<Vec<BasisElement>, usize>,
|
||||
index_by_signature: std::collections::HashMap<Vec<BasisElementIndex>, usize>,
|
||||
}
|
||||
|
||||
impl MultiVectorClassRegistry {
|
||||
|
|
@ -279,7 +264,7 @@ impl MultiVectorClassRegistry {
|
|||
self.classes.push(class);
|
||||
}
|
||||
|
||||
pub fn get(&self, signature: &[BasisElement]) -> Option<&MultiVectorClass> {
|
||||
pub fn get(&self, signature: &[BasisElementIndex]) -> Option<&MultiVectorClass> {
|
||||
self.index_by_signature.get(signature).map(|index| &self.classes[*index])
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
algebra::{BasisElement, Involution, MultiVectorClass, MultiVectorClassRegistry, Product},
|
||||
algebra::{BasisElement, BasisElementIndex, Involution, MultiVectorClass, MultiVectorClassRegistry, Product},
|
||||
ast::{AstNode, DataType, Expression, ExpressionContent, Parameter},
|
||||
};
|
||||
|
||||
|
|
@ -146,9 +146,9 @@ impl MultiVectorClass {
|
|||
self.grouped_basis.iter().flatten().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn signature(&self) -> Vec<BasisElement> {
|
||||
let mut signature = self.flat_basis();
|
||||
signature.sort();
|
||||
pub fn signature(&self) -> Vec<BasisElementIndex> {
|
||||
let mut signature: Vec<BasisElementIndex> = self.grouped_basis.iter().flatten().map(|element| element.index).collect();
|
||||
signature.sort_unstable();
|
||||
signature
|
||||
}
|
||||
|
||||
|
|
@ -204,29 +204,48 @@ impl MultiVectorClass {
|
|||
involution: &Involution,
|
||||
parameter_a: &Parameter<'a>,
|
||||
registry: &'a MultiVectorClassRegistry,
|
||||
project: bool,
|
||||
) -> AstNode<'a> {
|
||||
let a_flat_basis = parameter_a.multi_vector_class().flat_basis();
|
||||
let result_signature = a_flat_basis
|
||||
.iter()
|
||||
.map(|element| involution.terms[element.index as usize].unit.clone())
|
||||
.collect::<std::collections::HashSet<_>>();
|
||||
let mut result_signature = result_signature.into_iter().collect::<Vec<_>>();
|
||||
result_signature.sort();
|
||||
let mut result_signature = Vec::new();
|
||||
for a_element in a_flat_basis.iter() {
|
||||
for (in_element, out_element) in involution.terms.iter() {
|
||||
if in_element.index == a_element.index {
|
||||
result_signature.push(out_element.index);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if project {
|
||||
for (in_element, _out_element) in involution.terms.iter() {
|
||||
if !a_flat_basis.iter().any(|element| element.index == in_element.index) {
|
||||
return AstNode::None;
|
||||
}
|
||||
}
|
||||
}
|
||||
result_signature.sort_unstable();
|
||||
if let Some(result_class) = registry.get(&result_signature) {
|
||||
let result_flat_basis = result_class.flat_basis();
|
||||
let mut body = Vec::new();
|
||||
let mut base_index = 0;
|
||||
for result_group in result_class.grouped_basis.iter() {
|
||||
let size = result_group.len();
|
||||
let a_indices = (0..size)
|
||||
let (factors, a_indices): (Vec<_>, Vec<_>) = (0..size)
|
||||
.map(|index_in_group| {
|
||||
let result_element = &result_flat_basis[base_index + index_in_group];
|
||||
let a_element = &involution.terms[result_element.index as usize].unit;
|
||||
parameter_a
|
||||
.multi_vector_class()
|
||||
.index_in_group(a_flat_basis.iter().position(|element| element == a_element).unwrap())
|
||||
let involution_element = involution
|
||||
.terms
|
||||
.iter()
|
||||
.position(|(_in_element, out_element)| out_element.index == result_element.index)
|
||||
.unwrap();
|
||||
let (in_element, out_element) = &involution.terms[involution_element];
|
||||
let index_in_a = a_flat_basis.iter().position(|a_element| a_element.index == in_element.index).unwrap();
|
||||
(
|
||||
out_element.scalar * result_element.scalar * in_element.scalar * a_flat_basis[index_in_a].scalar,
|
||||
parameter_a.multi_vector_class().index_in_group(index_in_a),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
.unzip();
|
||||
let a_group_index = a_indices[0].0;
|
||||
let expression = Expression {
|
||||
size,
|
||||
|
|
@ -243,13 +262,7 @@ impl MultiVectorClass {
|
|||
}),
|
||||
Box::new(Expression {
|
||||
size,
|
||||
content: ExpressionContent::Constant(
|
||||
DataType::SimdVector(size),
|
||||
result_group
|
||||
.iter()
|
||||
.map(|element| involution.terms[element.index as usize].scalar)
|
||||
.collect(),
|
||||
),
|
||||
content: ExpressionContent::Constant(DataType::SimdVector(size), factors),
|
||||
}),
|
||||
),
|
||||
};
|
||||
|
|
@ -274,58 +287,6 @@ impl MultiVectorClass {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn conversion<'a>(name: &'static str, parameter_a: &Parameter<'a>, result_class: &'a MultiVectorClass) -> AstNode<'a> {
|
||||
if parameter_a.multi_vector_class() == result_class {
|
||||
return AstNode::None;
|
||||
}
|
||||
let a_flat_basis = parameter_a.multi_vector_class().flat_basis();
|
||||
let result_flat_basis = result_class.flat_basis();
|
||||
for element in &result_flat_basis {
|
||||
if !a_flat_basis.contains(element) {
|
||||
return AstNode::None;
|
||||
}
|
||||
}
|
||||
let mut body = Vec::new();
|
||||
let mut base_index = 0;
|
||||
for result_group in result_class.grouped_basis.iter() {
|
||||
let size = result_group.len();
|
||||
let a_indices = (0..size)
|
||||
.map(|index_in_group| {
|
||||
let result_element = &result_flat_basis[base_index + index_in_group];
|
||||
parameter_a
|
||||
.multi_vector_class()
|
||||
.index_in_group(a_flat_basis.iter().position(|a_element| a_element == result_element).unwrap())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let a_group_index = a_indices[0].0;
|
||||
let expression = Expression {
|
||||
size,
|
||||
content: ExpressionContent::Gather(
|
||||
Box::new(Expression {
|
||||
size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(),
|
||||
content: ExpressionContent::Variable(parameter_a.name),
|
||||
}),
|
||||
a_indices,
|
||||
),
|
||||
};
|
||||
body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression))));
|
||||
base_index += size;
|
||||
}
|
||||
AstNode::TraitImplementation {
|
||||
result: Parameter {
|
||||
name,
|
||||
data_type: DataType::MultiVector(result_class),
|
||||
},
|
||||
parameters: vec![parameter_a.clone()],
|
||||
body: vec![AstNode::ReturnStatement {
|
||||
expression: Box::new(Expression {
|
||||
size: 1,
|
||||
content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body),
|
||||
}),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sum<'a>(
|
||||
name: &'static str,
|
||||
parameter_a: &Parameter<'a>,
|
||||
|
|
@ -339,8 +300,8 @@ impl MultiVectorClass {
|
|||
.chain(b_flat_basis.iter())
|
||||
.cloned()
|
||||
.collect::<std::collections::HashSet<_>>();
|
||||
let mut result_signature = result_signature.into_iter().collect::<Vec<_>>();
|
||||
result_signature.sort();
|
||||
let mut result_signature = result_signature.into_iter().map(|element| element.index).collect::<Vec<_>>();
|
||||
result_signature.sort_unstable();
|
||||
if let Some(result_class) = registry.get(&result_signature) {
|
||||
let parameters = [(parameter_a, &a_flat_basis), (parameter_b, &b_flat_basis)];
|
||||
let mut body = Vec::new();
|
||||
|
|
@ -351,10 +312,10 @@ impl MultiVectorClass {
|
|||
let terms: Vec<_> = result_group
|
||||
.iter()
|
||||
.map(|result_element| {
|
||||
if let Some(index_in_group) = flat_basis.iter().position(|element| element == result_element) {
|
||||
let index_pair = parameter.multi_vector_class().index_in_group(index_in_group);
|
||||
if let Some(index_in_flat_basis) = flat_basis.iter().position(|element| element.index == result_element.index) {
|
||||
let index_pair = parameter.multi_vector_class().index_in_group(index_in_flat_basis);
|
||||
parameter_group_index = Some(index_pair.0);
|
||||
(1, index_pair)
|
||||
(result_element.scalar * flat_basis[index_in_flat_basis].scalar, index_pair)
|
||||
} else {
|
||||
(0, (0, 0))
|
||||
}
|
||||
|
|
@ -428,20 +389,30 @@ impl MultiVectorClass {
|
|||
let b_flat_basis = parameter_b.multi_vector_class().flat_basis();
|
||||
let mut result_signature = std::collections::HashSet::new();
|
||||
for product_term in product.terms.iter() {
|
||||
if a_flat_basis.contains(&product_term.factor_a) && b_flat_basis.contains(&product_term.factor_b) {
|
||||
result_signature.insert(product_term.product.unit.clone());
|
||||
if a_flat_basis.iter().any(|e| e.index == product_term.factor_a.index)
|
||||
&& b_flat_basis.iter().any(|e| e.index == product_term.factor_b.index)
|
||||
{
|
||||
result_signature.insert(product_term.product.index);
|
||||
}
|
||||
}
|
||||
let mut result_signature = result_signature.into_iter().collect::<Vec<_>>();
|
||||
result_signature.sort();
|
||||
result_signature.sort_unstable();
|
||||
if let Some(result_class) = registry.get(&result_signature) {
|
||||
let result_flat_basis = result_class.flat_basis();
|
||||
let mut sorted_terms = vec![vec![(0, 0); a_flat_basis.len()]; result_flat_basis.len()];
|
||||
for product_term in product.terms.iter() {
|
||||
if let Some(y) = result_flat_basis.iter().position(|e| e == &product_term.product.unit) {
|
||||
if let Some(x) = a_flat_basis.iter().position(|e| e == &product_term.factor_a) {
|
||||
if let Some(gather_index) = b_flat_basis.iter().position(|e| e == &product_term.factor_b) {
|
||||
sorted_terms[y][x] = (product_term.product.scalar, gather_index);
|
||||
if let Some(y) = result_flat_basis.iter().position(|e| e.index == product_term.product.index) {
|
||||
if let Some(x) = a_flat_basis.iter().position(|e| e.index == product_term.factor_a.index) {
|
||||
if let Some(gather_index) = b_flat_basis.iter().position(|e| e.index == product_term.factor_b.index) {
|
||||
sorted_terms[y][x] = (
|
||||
result_flat_basis[y].scalar
|
||||
* product_term.product.scalar
|
||||
* a_flat_basis[x].scalar
|
||||
* product_term.factor_a.scalar
|
||||
* b_flat_basis[gather_index].scalar
|
||||
* product_term.factor_b.scalar,
|
||||
gather_index,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -469,7 +440,7 @@ impl MultiVectorClass {
|
|||
},
|
||||
vec![(0, 0); expression.size],
|
||||
vec![(0, 0); expression.size],
|
||||
vec![false; expression.size],
|
||||
vec![0; expression.size],
|
||||
);
|
||||
for (index_in_a, a_terms) in transposed_terms.enumerate() {
|
||||
if a_terms.iter().all(|(factor, _)| *factor == 0) {
|
||||
|
|
@ -488,10 +459,7 @@ impl MultiVectorClass {
|
|||
.enumerate()
|
||||
.map(|(index, (factor, _index_pair))| b_indices[if *factor == 0 { non_zero_index } else { index }])
|
||||
.collect::<Vec<_>>();
|
||||
let is_contractable = a_terms
|
||||
.iter()
|
||||
.enumerate()
|
||||
.all(|(i, (factor, _))| *factor == 0 || *factor == 1 && !contraction.4[i])
|
||||
let is_contractable = a_terms.iter().enumerate().all(|(i, (factor, _))| *factor == 0 || contraction.4[i] == 0)
|
||||
&& (contraction.0.content == ExpressionContent::None
|
||||
|| contraction.0.size == parameter_a.multi_vector_class().grouped_basis[a_group_index].len())
|
||||
&& (contraction.1.content == ExpressionContent::None
|
||||
|
|
@ -511,10 +479,10 @@ impl MultiVectorClass {
|
|||
contraction.3 = b_indices.iter().map(|(b_group_index, _)| (*b_group_index, 0)).collect();
|
||||
}
|
||||
for (i, (factor, _index_in_b)) in a_terms.iter().enumerate() {
|
||||
if *factor == 1 {
|
||||
if *factor != 0 {
|
||||
contraction.2[i] = a_indices[i];
|
||||
contraction.3[i] = b_indices[i];
|
||||
contraction.4[i] = true;
|
||||
contraction.4[i] = *factor;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -563,7 +531,7 @@ impl MultiVectorClass {
|
|||
};
|
||||
}
|
||||
}
|
||||
if contraction.4.iter().any(|mask| *mask) {
|
||||
if contraction.4.iter().any(|scalar| *scalar != 0) {
|
||||
expression = Expression {
|
||||
size,
|
||||
content: ExpressionContent::Add(
|
||||
|
|
@ -586,10 +554,7 @@ impl MultiVectorClass {
|
|||
}),
|
||||
Box::new(Expression {
|
||||
size,
|
||||
content: ExpressionContent::Constant(
|
||||
DataType::SimdVector(size),
|
||||
contraction.4.iter().map(|value| *value as isize).collect(),
|
||||
),
|
||||
content: ExpressionContent::Constant(DataType::SimdVector(size), contraction.4),
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ mod glsl;
|
|||
mod rust;
|
||||
|
||||
use crate::{
|
||||
algebra::{BasisElement, GeometricAlgebra, Involution, MultiVectorClass, MultiVectorClassRegistry, Product, ScaledElement},
|
||||
algebra::{BasisElement, GeometricAlgebra, Involution, MultiVectorClass, MultiVectorClassRegistry, Product},
|
||||
ast::{AstNode, DataType, Parameter},
|
||||
emit::Emitter,
|
||||
};
|
||||
|
|
@ -31,13 +31,9 @@ fn main() {
|
|||
let involutions = Involution::involutions(&algebra);
|
||||
let products = Product::products(&algebra);
|
||||
let basis = algebra.sorted_basis();
|
||||
for a in basis.iter() {
|
||||
for b in basis.iter() {
|
||||
print!(
|
||||
"{:1$} ",
|
||||
ScaledElement::product(&ScaledElement::from(&a), &ScaledElement::from(&b), &algebra),
|
||||
generator_squares.len() + 2
|
||||
);
|
||||
for a in basis.iter() {
|
||||
print!("{:1$} ", BasisElement::product(&a, &b, &algebra), generator_squares.len() + 2);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
|
@ -53,7 +49,7 @@ fn main() {
|
|||
.map(|group_descriptor| {
|
||||
group_descriptor
|
||||
.split(',')
|
||||
.map(|element_name| BasisElement::new(element_name))
|
||||
.map(|element_name| BasisElement::parse(element_name, &algebra))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
|
|
@ -80,7 +76,7 @@ fn main() {
|
|||
}
|
||||
}
|
||||
for (name, involution) in involutions.iter() {
|
||||
let ast_node = MultiVectorClass::involution(name, &involution, ¶meter_a, ®istry);
|
||||
let ast_node = MultiVectorClass::involution(name, &involution, ¶meter_a, ®istry, false);
|
||||
emitter.emit(&ast_node).unwrap();
|
||||
if ast_node != AstNode::None {
|
||||
single_trait_implementations.insert(name.to_string(), ast_node);
|
||||
|
|
@ -93,12 +89,14 @@ fn main() {
|
|||
name: "other",
|
||||
data_type: DataType::MultiVector(class_b),
|
||||
};
|
||||
if class_a != class_b {
|
||||
let name = "Into";
|
||||
let ast_node = MultiVectorClass::conversion(name, ¶meter_a, ¶meter_b.multi_vector_class());
|
||||
let ast_node = MultiVectorClass::involution(name, &Involution::projection(&class_b), ¶meter_a, ®istry, true);
|
||||
emitter.emit(&ast_node).unwrap();
|
||||
if ast_node != AstNode::None {
|
||||
trait_implementations.insert(name.to_string(), ast_node);
|
||||
}
|
||||
}
|
||||
for name in &["Add", "Sub"] {
|
||||
let ast_node = MultiVectorClass::sum(*name, ¶meter_a, ¶meter_b, ®istry);
|
||||
emitter.emit(&ast_node).unwrap();
|
||||
|
|
@ -135,7 +133,7 @@ fn main() {
|
|||
}
|
||||
for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() {
|
||||
if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") {
|
||||
if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement { index: 0 }]] {
|
||||
if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement::from_index(0)]] {
|
||||
if let Some(magnitude) = single_trait_implementations.get("Magnitude") {
|
||||
let signum = MultiVectorClass::derive_signum("Signum", &geometric_product, &magnitude, ¶meter_a);
|
||||
emitter.emit(&signum).unwrap();
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue