1226 lines
58 KiB
Rust
1226 lines
58 KiB
Rust
use crate::{
|
|
algebra::{BasisElement, BasisElementIndex, Involution, MultiVectorClass, MultiVectorClassRegistry, Product},
|
|
ast::{AstNode, DataType, Expression, ExpressionContent, Parameter},
|
|
};
|
|
|
|
#[macro_export]
|
|
macro_rules! result_of_trait {
|
|
($ast_node:expr) => {
|
|
match $ast_node {
|
|
AstNode::TraitImplementation { ref result, .. } => result,
|
|
_ => unreachable!(),
|
|
}
|
|
};
|
|
}
|
|
|
|
pub fn simplify_and_legalize(expression: Box<Expression>) -> Box<Expression> {
|
|
match expression.content {
|
|
ExpressionContent::Gather(mut inner_expression, indices) => {
|
|
if let Some(first_index_pair) = indices.first() {
|
|
inner_expression = simplify_and_legalize(inner_expression);
|
|
if indices.iter().all(|index_pair| index_pair == first_index_pair) {
|
|
Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Gather(inner_expression, vec![*first_index_pair]),
|
|
})
|
|
} else if inner_expression.size == expression.size && indices.iter().all(|(array_index, _)| *array_index == first_index_pair.0) {
|
|
inner_expression = Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Access(inner_expression, first_index_pair.0),
|
|
});
|
|
if indices.iter().enumerate().any(|(i, (_, component_index))| i != *component_index) {
|
|
Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Swizzle(
|
|
inner_expression,
|
|
indices.iter().map(|(_, component_index)| *component_index).collect(),
|
|
),
|
|
})
|
|
} else {
|
|
inner_expression
|
|
}
|
|
} else {
|
|
Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Gather(inner_expression, indices),
|
|
})
|
|
}
|
|
} else {
|
|
Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::None,
|
|
})
|
|
}
|
|
}
|
|
ExpressionContent::Constant(ref data_type, ref values) => {
|
|
let first_value = values.first().unwrap();
|
|
if values.iter().all(|value| value == first_value) {
|
|
Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Constant(data_type.clone(), vec![*first_value]),
|
|
})
|
|
} else {
|
|
expression
|
|
}
|
|
}
|
|
ExpressionContent::Add(mut a, mut b) => {
|
|
if let ExpressionContent::Multiply(ref c, ref d) = b.content {
|
|
if let ExpressionContent::Multiply(ref e, ref f) = d.content {
|
|
if let ExpressionContent::Constant(_data_type, values) = &f.content {
|
|
if values.iter().all(|value| *value == -1) {
|
|
b = Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Multiply(c.clone(), e.clone()),
|
|
});
|
|
return simplify_and_legalize(Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Subtract(a, b),
|
|
}));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
a = simplify_and_legalize(a);
|
|
b = simplify_and_legalize(b);
|
|
if a.content == ExpressionContent::None {
|
|
b
|
|
} else if b.content == ExpressionContent::None {
|
|
a
|
|
} else {
|
|
Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Add(a, b),
|
|
})
|
|
}
|
|
}
|
|
ExpressionContent::Subtract(mut a, mut b) => {
|
|
a = simplify_and_legalize(a);
|
|
b = simplify_and_legalize(b);
|
|
if a.content == ExpressionContent::None {
|
|
let constant = Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Constant(DataType::SimdVector(expression.size), vec![0]),
|
|
};
|
|
Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Subtract(Box::new(constant), b),
|
|
})
|
|
} else if b.content == ExpressionContent::None {
|
|
a
|
|
} else {
|
|
Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Subtract(a, b),
|
|
})
|
|
}
|
|
}
|
|
ExpressionContent::Multiply(mut a, mut b) => {
|
|
a = simplify_and_legalize(a);
|
|
b = simplify_and_legalize(b);
|
|
if let ExpressionContent::Constant(_, _) = a.content {
|
|
std::mem::swap(&mut a, &mut b)
|
|
}
|
|
if a.content == ExpressionContent::None {
|
|
b
|
|
} else {
|
|
match b.content {
|
|
ExpressionContent::None => a,
|
|
ExpressionContent::Constant(_data_type, c) if c.iter().all(|c| *c == 1) => a,
|
|
ExpressionContent::Constant(_data_type, c) if c.iter().all(|c| *c == 0) => Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::None,
|
|
}),
|
|
_ => Box::new(Expression {
|
|
size: expression.size,
|
|
content: ExpressionContent::Multiply(a, b),
|
|
}),
|
|
}
|
|
}
|
|
}
|
|
_ => expression,
|
|
}
|
|
}
|
|
|
|
impl MultiVectorClass {
|
|
pub fn flat_basis(&self) -> Vec<BasisElement> {
|
|
self.grouped_basis.iter().flatten().cloned().collect()
|
|
}
|
|
|
|
pub fn signature(&self) -> Vec<BasisElementIndex> {
|
|
let mut signature: Vec<BasisElementIndex> = self.grouped_basis.iter().flatten().map(|element| element.index).collect();
|
|
signature.sort_unstable();
|
|
signature
|
|
}
|
|
|
|
pub fn index_in_group(&self, mut index: usize) -> (usize, usize) {
|
|
for (group_index, group) in self.grouped_basis.iter().enumerate() {
|
|
if index >= group.len() {
|
|
index -= group.len();
|
|
} else {
|
|
return (group_index, index);
|
|
}
|
|
}
|
|
unreachable!()
|
|
}
|
|
|
|
pub fn constant<'a>(&'a self, name: &'static str) -> AstNode<'a> {
|
|
let (scalar_value, other_values) = match name {
|
|
"Zero" => (0, 0),
|
|
"One" => (1, 0),
|
|
_ => unreachable!(),
|
|
};
|
|
let mut body = Vec::new();
|
|
for result_group in self.grouped_basis.iter() {
|
|
let size = result_group.len();
|
|
let expression = Expression {
|
|
size,
|
|
content: ExpressionContent::Constant(
|
|
DataType::SimdVector(size),
|
|
result_group
|
|
.iter()
|
|
.map(|element| if element.index == 0 { scalar_value } else { other_values })
|
|
.collect(),
|
|
),
|
|
};
|
|
body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression))));
|
|
}
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: DataType::MultiVector(self),
|
|
},
|
|
parameters: vec![],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(self, "Constructor", body),
|
|
}),
|
|
}],
|
|
}
|
|
}
|
|
|
|
pub fn involution<'a>(
|
|
name: &'static str,
|
|
involution: &Involution,
|
|
parameter_a: &Parameter<'a>,
|
|
registry: &'a MultiVectorClassRegistry,
|
|
project: bool,
|
|
) -> AstNode<'a> {
|
|
let a_flat_basis = parameter_a.multi_vector_class().flat_basis();
|
|
let mut result_signature = Vec::new();
|
|
for a_element in a_flat_basis.iter() {
|
|
for (in_element, out_element) in involution.terms.iter() {
|
|
if in_element.index == a_element.index {
|
|
result_signature.push(out_element.index);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if project {
|
|
for (in_element, _out_element) in involution.terms.iter() {
|
|
if !a_flat_basis.iter().any(|element| element.index == in_element.index) {
|
|
return AstNode::None;
|
|
}
|
|
}
|
|
}
|
|
result_signature.sort_unstable();
|
|
if let Some(result_class) = registry.get(&result_signature) {
|
|
let result_flat_basis = result_class.flat_basis();
|
|
let mut body = Vec::new();
|
|
let mut base_index = 0;
|
|
for result_group in result_class.grouped_basis.iter() {
|
|
let size = result_group.len();
|
|
let (factors, a_indices): (Vec<_>, Vec<_>) = (0..size)
|
|
.map(|index_in_group| {
|
|
let result_element = &result_flat_basis[base_index + index_in_group];
|
|
let involution_element = involution
|
|
.terms
|
|
.iter()
|
|
.position(|(_in_element, out_element)| out_element.index == result_element.index)
|
|
.unwrap();
|
|
let (in_element, out_element) = &involution.terms[involution_element];
|
|
let index_in_a = a_flat_basis.iter().position(|a_element| a_element.index == in_element.index).unwrap();
|
|
(
|
|
out_element.scalar * result_element.scalar * in_element.scalar * a_flat_basis[index_in_a].scalar,
|
|
parameter_a.multi_vector_class().index_in_group(index_in_a),
|
|
)
|
|
})
|
|
.unzip();
|
|
let a_group_index = a_indices[0].0;
|
|
let expression = Expression {
|
|
size,
|
|
content: ExpressionContent::Multiply(
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Gather(
|
|
Box::new(Expression {
|
|
size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(),
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
a_indices,
|
|
),
|
|
}),
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Constant(DataType::SimdVector(size), factors),
|
|
}),
|
|
),
|
|
};
|
|
body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression))));
|
|
base_index += size;
|
|
}
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: DataType::MultiVector(result_class),
|
|
},
|
|
parameters: vec![parameter_a.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body),
|
|
}),
|
|
}],
|
|
}
|
|
} else {
|
|
AstNode::None
|
|
}
|
|
}
|
|
|
|
pub fn sum<'a>(
|
|
name: &'static str,
|
|
parameter_a: &Parameter<'a>,
|
|
parameter_b: &Parameter<'a>,
|
|
registry: &'a MultiVectorClassRegistry,
|
|
) -> AstNode<'a> {
|
|
let a_flat_basis = parameter_a.multi_vector_class().flat_basis();
|
|
let b_flat_basis = parameter_b.multi_vector_class().flat_basis();
|
|
let result_signature = a_flat_basis
|
|
.iter()
|
|
.chain(b_flat_basis.iter())
|
|
.cloned()
|
|
.collect::<std::collections::HashSet<_>>();
|
|
let mut result_signature = result_signature.into_iter().map(|element| element.index).collect::<Vec<_>>();
|
|
result_signature.sort_unstable();
|
|
if let Some(result_class) = registry.get(&result_signature) {
|
|
let parameters = [(parameter_a, &a_flat_basis), (parameter_b, &b_flat_basis)];
|
|
let mut body = Vec::new();
|
|
for result_group in result_class.grouped_basis.iter() {
|
|
let size = result_group.len();
|
|
let mut expressions = parameters.iter().map(|(parameter, flat_basis)| {
|
|
let mut parameter_group_index = None;
|
|
let terms: Vec<_> = result_group
|
|
.iter()
|
|
.map(|result_element| {
|
|
if let Some(index_in_flat_basis) = flat_basis.iter().position(|element| element.index == result_element.index) {
|
|
let index_pair = parameter.multi_vector_class().index_in_group(index_in_flat_basis);
|
|
parameter_group_index = Some(index_pair.0);
|
|
(result_element.scalar * flat_basis[index_in_flat_basis].scalar, index_pair)
|
|
} else {
|
|
(0, (0, 0))
|
|
}
|
|
})
|
|
.collect();
|
|
Expression {
|
|
size,
|
|
content: ExpressionContent::Multiply(
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Gather(
|
|
Box::new(Expression {
|
|
size: if let Some(index) = parameter_group_index {
|
|
parameter.multi_vector_class().grouped_basis[index].len()
|
|
} else {
|
|
size
|
|
},
|
|
content: ExpressionContent::Variable(parameter.name),
|
|
}),
|
|
terms.iter().map(|(_factor, index_pair)| index_pair).cloned().collect(),
|
|
),
|
|
}),
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Constant(
|
|
DataType::SimdVector(size),
|
|
terms.iter().map(|(factor, _index_pair)| *factor).collect::<Vec<_>>(),
|
|
),
|
|
}),
|
|
),
|
|
}
|
|
});
|
|
body.push((
|
|
DataType::SimdVector(size),
|
|
*simplify_and_legalize(Box::new(Expression {
|
|
size,
|
|
content: match name {
|
|
"Add" => ExpressionContent::Add(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
|
|
"Sub" => ExpressionContent::Subtract(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
|
|
_ => unreachable!(),
|
|
},
|
|
})),
|
|
));
|
|
}
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: DataType::MultiVector(result_class),
|
|
},
|
|
parameters: vec![parameter_a.clone(), parameter_b.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body),
|
|
}),
|
|
}],
|
|
}
|
|
} else {
|
|
AstNode::None
|
|
}
|
|
}
|
|
|
|
pub fn product<'a>(
|
|
name: &'static str,
|
|
product: &Product,
|
|
parameter_a: &Parameter<'a>,
|
|
parameter_b: &Parameter<'a>,
|
|
registry: &'a MultiVectorClassRegistry,
|
|
) -> AstNode<'a> {
|
|
let a_flat_basis = parameter_a.multi_vector_class().flat_basis();
|
|
let b_flat_basis = parameter_b.multi_vector_class().flat_basis();
|
|
let mut result_signature = std::collections::HashSet::new();
|
|
for product_term in product.terms.iter() {
|
|
if a_flat_basis.iter().any(|e| e.index == product_term.factor_a.index)
|
|
&& b_flat_basis.iter().any(|e| e.index == product_term.factor_b.index)
|
|
{
|
|
result_signature.insert(product_term.product.index);
|
|
}
|
|
}
|
|
let mut result_signature = result_signature.into_iter().collect::<Vec<_>>();
|
|
result_signature.sort_unstable();
|
|
if let Some(result_class) = registry.get(&result_signature) {
|
|
let result_flat_basis = result_class.flat_basis();
|
|
let mut sorted_terms = vec![vec![(0, 0); a_flat_basis.len()]; result_flat_basis.len()];
|
|
for product_term in product.terms.iter() {
|
|
if let Some(y) = result_flat_basis.iter().position(|e| e.index == product_term.product.index) {
|
|
if let Some(x) = a_flat_basis.iter().position(|e| e.index == product_term.factor_a.index) {
|
|
if let Some(gather_index) = b_flat_basis.iter().position(|e| e.index == product_term.factor_b.index) {
|
|
sorted_terms[y][x] = (
|
|
result_flat_basis[y].scalar
|
|
* product_term.product.scalar
|
|
* a_flat_basis[x].scalar
|
|
* product_term.factor_a.scalar
|
|
* b_flat_basis[gather_index].scalar
|
|
* product_term.factor_b.scalar,
|
|
gather_index,
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
let mut body = Vec::new();
|
|
let mut base_index = 0;
|
|
for result_group in result_class.grouped_basis.iter() {
|
|
let size = result_group.len();
|
|
let mut expression = Expression {
|
|
size,
|
|
content: ExpressionContent::None,
|
|
};
|
|
let result_terms = (0..size)
|
|
.map(|index_in_group| &sorted_terms[base_index + index_in_group])
|
|
.collect::<Vec<_>>();
|
|
let transposed_terms = (0..result_terms[0].len()).map(|i| result_terms.iter().map(|inner| inner[i]).collect::<Vec<_>>());
|
|
let mut contraction = (
|
|
Expression {
|
|
size,
|
|
content: ExpressionContent::None,
|
|
},
|
|
Expression {
|
|
size,
|
|
content: ExpressionContent::None,
|
|
},
|
|
vec![(0, 0); expression.size],
|
|
vec![(0, 0); expression.size],
|
|
vec![0; expression.size],
|
|
);
|
|
for (index_in_a, a_terms) in transposed_terms.enumerate() {
|
|
if a_terms.iter().all(|(factor, _)| *factor == 0) {
|
|
continue;
|
|
}
|
|
let (a_group_index, a_index_in_group) = parameter_a.multi_vector_class().index_in_group(index_in_a);
|
|
let a_indices = a_terms.iter().map(|_| (a_group_index, a_index_in_group)).collect::<Vec<_>>();
|
|
let b_indices = a_terms
|
|
.iter()
|
|
.map(|(_, index_in_b)| parameter_b.multi_vector_class().index_in_group(*index_in_b))
|
|
.collect::<Vec<_>>();
|
|
let non_zero_index = a_terms.iter().position(|(factor, _index_pair)| *factor != 0).unwrap();
|
|
let b_group_index = b_indices[non_zero_index].0;
|
|
let b_indices = a_terms
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(index, (factor, _index_pair))| b_indices[if *factor == 0 { non_zero_index } else { index }])
|
|
.collect::<Vec<_>>();
|
|
let is_contractable = a_terms.iter().enumerate().all(|(i, (factor, _))| *factor == 0 || contraction.4[i] == 0)
|
|
&& (contraction.0.content == ExpressionContent::None
|
|
|| contraction.0.size == parameter_a.multi_vector_class().grouped_basis[a_group_index].len())
|
|
&& (contraction.1.content == ExpressionContent::None
|
|
|| contraction.1.size == parameter_b.multi_vector_class().grouped_basis[b_group_index].len());
|
|
if is_contractable && a_terms.iter().any(|(factor, _)| *factor == 0) {
|
|
if contraction.0.content == ExpressionContent::None {
|
|
assert!(contraction.1.content == ExpressionContent::None);
|
|
contraction.0 = Expression {
|
|
size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(),
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
};
|
|
contraction.1 = Expression {
|
|
size: parameter_b.multi_vector_class().grouped_basis[b_group_index].len(),
|
|
content: ExpressionContent::Variable(parameter_b.name),
|
|
};
|
|
contraction.2 = a_indices.iter().map(|(a_group_index, _)| (*a_group_index, 0)).collect();
|
|
contraction.3 = b_indices.iter().map(|(b_group_index, _)| (*b_group_index, 0)).collect();
|
|
}
|
|
for (i, (factor, _index_in_b)) in a_terms.iter().enumerate() {
|
|
if *factor != 0 {
|
|
contraction.2[i] = a_indices[i];
|
|
contraction.3[i] = b_indices[i];
|
|
contraction.4[i] = *factor;
|
|
}
|
|
}
|
|
} else {
|
|
expression = Expression {
|
|
size,
|
|
content: ExpressionContent::Add(
|
|
Box::new(expression),
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Multiply(
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Gather(
|
|
Box::new(Expression {
|
|
size: parameter_a.multi_vector_class().grouped_basis[a_group_index].len(),
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
a_indices,
|
|
),
|
|
}),
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Multiply(
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Gather(
|
|
Box::new(Expression {
|
|
size: parameter_b.multi_vector_class().grouped_basis[b_group_index].len(),
|
|
content: ExpressionContent::Variable(parameter_b.name),
|
|
}),
|
|
b_indices,
|
|
),
|
|
}),
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Constant(
|
|
DataType::SimdVector(size),
|
|
a_terms.iter().map(|(factor, _)| *factor).collect::<Vec<_>>(),
|
|
),
|
|
}),
|
|
),
|
|
}),
|
|
),
|
|
}),
|
|
),
|
|
};
|
|
}
|
|
}
|
|
if contraction.4.iter().any(|scalar| *scalar != 0) {
|
|
expression = Expression {
|
|
size,
|
|
content: ExpressionContent::Add(
|
|
Box::new(expression),
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Multiply(
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Multiply(
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Gather(Box::new(contraction.0), contraction.2),
|
|
}),
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Gather(Box::new(contraction.1), contraction.3),
|
|
}),
|
|
),
|
|
}),
|
|
Box::new(Expression {
|
|
size,
|
|
content: ExpressionContent::Constant(DataType::SimdVector(size), contraction.4),
|
|
}),
|
|
),
|
|
}),
|
|
),
|
|
};
|
|
}
|
|
if expression.content == ExpressionContent::None {
|
|
expression = Expression {
|
|
size,
|
|
content: ExpressionContent::Constant(DataType::SimdVector(size), (0..size).map(|_| 0).collect()),
|
|
};
|
|
}
|
|
body.push((DataType::SimdVector(size), *simplify_and_legalize(Box::new(expression))));
|
|
base_index += size;
|
|
}
|
|
if body.is_empty() {
|
|
AstNode::None
|
|
} else {
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: DataType::MultiVector(result_class),
|
|
},
|
|
parameters: vec![parameter_a.clone(), parameter_b.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(result_class, "Constructor", body),
|
|
}),
|
|
}],
|
|
}
|
|
}
|
|
} else {
|
|
AstNode::None
|
|
}
|
|
}
|
|
|
|
pub fn derive_multiplication<'a>(
|
|
name: &'static str,
|
|
geometric_product: &AstNode<'a>,
|
|
parameter_a: &Parameter<'a>,
|
|
parameter_b: &Parameter<'a>,
|
|
) -> AstNode<'a> {
|
|
let geometric_product_result = result_of_trait!(geometric_product);
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: geometric_product_result.data_type.clone(),
|
|
},
|
|
parameters: vec![parameter_a.clone(), parameter_b.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
geometric_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(parameter_b.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_b.name),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
}],
|
|
}
|
|
}
|
|
|
|
pub fn derive_squared_magnitude<'a>(
|
|
name: &'static str,
|
|
scalar_product: &AstNode<'a>,
|
|
involution: &AstNode<'a>,
|
|
parameter_a: &Parameter<'a>,
|
|
) -> AstNode<'a> {
|
|
let scalar_product_result = result_of_trait!(scalar_product);
|
|
let involution_result = result_of_trait!(involution);
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: scalar_product_result.data_type.clone(),
|
|
},
|
|
parameters: vec![parameter_a.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
scalar_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(involution_result.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
involution_result.name,
|
|
vec![],
|
|
),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
}],
|
|
}
|
|
}
|
|
|
|
pub fn derive_magnitude<'a>(name: &'static str, squared_magnitude: &AstNode<'a>, parameter_a: &Parameter<'a>) -> AstNode<'a> {
|
|
let squared_magnitude_result = result_of_trait!(squared_magnitude);
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: squared_magnitude_result.data_type.clone(),
|
|
},
|
|
parameters: vec![parameter_a.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(
|
|
squared_magnitude_result.multi_vector_class(),
|
|
"Constructor",
|
|
vec![(
|
|
DataType::SimdVector(1),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::SquareRoot(Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Access(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
squared_magnitude_result.name,
|
|
vec![],
|
|
),
|
|
}),
|
|
0,
|
|
),
|
|
})),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
}],
|
|
}
|
|
}
|
|
|
|
pub fn derive_signum<'a>(
|
|
name: &'static str,
|
|
geometric_product: &AstNode<'a>,
|
|
magnitude: &AstNode<'a>,
|
|
parameter_a: &Parameter<'a>,
|
|
) -> AstNode<'a> {
|
|
let geometric_product_result = result_of_trait!(geometric_product);
|
|
let magnitude_result = result_of_trait!(magnitude);
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: geometric_product_result.data_type.clone(),
|
|
},
|
|
parameters: vec![parameter_a.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
geometric_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(magnitude_result.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(
|
|
magnitude_result.multi_vector_class(),
|
|
"Constructor",
|
|
vec![(
|
|
DataType::SimdVector(1),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Divide(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Constant(DataType::SimdVector(1), vec![1]),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Access(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
magnitude_result.name,
|
|
vec![],
|
|
),
|
|
}),
|
|
0,
|
|
),
|
|
}),
|
|
),
|
|
},
|
|
)],
|
|
),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
}],
|
|
}
|
|
}
|
|
|
|
pub fn derive_inverse<'a>(
|
|
name: &'static str,
|
|
geometric_product: &AstNode<'a>,
|
|
squared_magnitude: &AstNode<'a>,
|
|
involution: &AstNode<'a>,
|
|
parameter_a: &Parameter<'a>,
|
|
) -> AstNode<'a> {
|
|
let geometric_product_result = result_of_trait!(geometric_product);
|
|
let squared_magnitude_result = result_of_trait!(squared_magnitude);
|
|
let involution_result = result_of_trait!(involution);
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: geometric_product_result.data_type.clone(),
|
|
},
|
|
parameters: vec![parameter_a.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
involution_result.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
involution_result.name,
|
|
vec![],
|
|
),
|
|
}),
|
|
geometric_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(squared_magnitude_result.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(
|
|
squared_magnitude_result.multi_vector_class(),
|
|
"Constructor",
|
|
vec![(
|
|
DataType::SimdVector(1),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Divide(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Constant(DataType::SimdVector(1), vec![1]),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Access(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
squared_magnitude_result.name,
|
|
vec![],
|
|
),
|
|
}),
|
|
0,
|
|
),
|
|
}),
|
|
),
|
|
},
|
|
)],
|
|
),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
}],
|
|
}
|
|
}
|
|
|
|
pub fn derive_power_of_integer<'a>(
|
|
name: &'static str,
|
|
geometric_product: &AstNode<'a>,
|
|
constant_one: &AstNode<'a>,
|
|
inverse: &AstNode<'a>,
|
|
parameter_a: &Parameter<'a>,
|
|
parameter_b: &Parameter<'a>,
|
|
) -> AstNode<'a> {
|
|
let geometric_product_result = result_of_trait!(geometric_product);
|
|
let constant_one_result = result_of_trait!(constant_one);
|
|
let inverse_result = result_of_trait!(inverse);
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: parameter_a.data_type.clone(),
|
|
},
|
|
parameters: vec![parameter_a.clone(), parameter_b.clone()],
|
|
body: vec![
|
|
AstNode::IfThenBlock {
|
|
condition: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Equal(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_b.name),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Constant(DataType::Integer, vec![0]),
|
|
}),
|
|
),
|
|
}),
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(parameter_a.multi_vector_class(), constant_one_result.name, vec![]),
|
|
}),
|
|
}],
|
|
},
|
|
AstNode::VariableAssignment {
|
|
name: "x",
|
|
data_type: Some(parameter_a.data_type.clone()),
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Select(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::LessThan(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_b.name),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Constant(DataType::Integer, vec![0]),
|
|
}),
|
|
),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
inverse_result.name,
|
|
vec![],
|
|
),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
),
|
|
}),
|
|
},
|
|
AstNode::VariableAssignment {
|
|
name: "y",
|
|
data_type: Some(parameter_a.data_type.clone()),
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeClassMethod(parameter_a.multi_vector_class(), constant_one_result.name, vec![]),
|
|
}),
|
|
},
|
|
AstNode::VariableAssignment {
|
|
name: "n",
|
|
data_type: Some(DataType::Integer),
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
DataType::Integer,
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_b.name),
|
|
}),
|
|
"Abs",
|
|
vec![],
|
|
),
|
|
}),
|
|
},
|
|
AstNode::WhileLoopBlock {
|
|
condition: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::LessThan(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Constant(DataType::Integer, vec![1]),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("n"),
|
|
}),
|
|
),
|
|
}),
|
|
body: vec![
|
|
AstNode::IfThenBlock {
|
|
condition: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Equal(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::LogicAnd(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("n"),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Constant(DataType::Integer, vec![1]),
|
|
}),
|
|
),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Constant(DataType::Integer, vec![1]),
|
|
}),
|
|
),
|
|
}),
|
|
body: vec![AstNode::VariableAssignment {
|
|
name: "y",
|
|
data_type: None,
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("x"),
|
|
}),
|
|
geometric_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(parameter_a.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("y"),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
}],
|
|
},
|
|
AstNode::VariableAssignment {
|
|
name: "x",
|
|
data_type: None,
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("x"),
|
|
}),
|
|
geometric_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(parameter_a.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("x"),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
},
|
|
AstNode::VariableAssignment {
|
|
name: "n",
|
|
data_type: None,
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::BitShiftRight(
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("n"),
|
|
}),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Constant(DataType::Integer, vec![1]),
|
|
}),
|
|
),
|
|
}),
|
|
},
|
|
],
|
|
},
|
|
AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("x"),
|
|
}),
|
|
geometric_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(parameter_a.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable("y"),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
},
|
|
],
|
|
}
|
|
}
|
|
|
|
pub fn derive_division<'a>(
|
|
name: &'static str,
|
|
geometric_product: &AstNode<'a>,
|
|
inverse: &AstNode<'a>,
|
|
parameter_a: &Parameter<'a>,
|
|
parameter_b: &Parameter<'a>,
|
|
) -> AstNode<'a> {
|
|
let geometric_product_result = result_of_trait!(geometric_product);
|
|
let inverse_result = result_of_trait!(inverse);
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: geometric_product_result.data_type.clone(),
|
|
},
|
|
parameters: vec![parameter_a.clone(), parameter_b.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
geometric_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(inverse_result.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_b.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_b.name),
|
|
}),
|
|
inverse_result.name,
|
|
vec![],
|
|
),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
}],
|
|
}
|
|
}
|
|
|
|
pub fn derive_sandwich_product<'a>(
|
|
name: &'static str,
|
|
geometric_product: &AstNode<'a>,
|
|
geometric_product_2: &AstNode<'a>,
|
|
involution: &AstNode<'a>,
|
|
conversion: Option<&AstNode<'a>>,
|
|
parameter_a: &Parameter<'a>,
|
|
parameter_b: &Parameter<'a>,
|
|
) -> AstNode<'a> {
|
|
let geometric_product_result = result_of_trait!(geometric_product);
|
|
let geometric_product_2_result = result_of_trait!(geometric_product_2);
|
|
let involution_result = result_of_trait!(involution);
|
|
let product = Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
geometric_product_result.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
geometric_product_result.name,
|
|
vec![(
|
|
DataType::MultiVector(parameter_b.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_b.name),
|
|
},
|
|
)],
|
|
),
|
|
}),
|
|
geometric_product_2_result.name,
|
|
vec![(
|
|
DataType::MultiVector(involution_result.multi_vector_class()),
|
|
Expression {
|
|
size: 1,
|
|
content: match name {
|
|
"Reflection" => ExpressionContent::Variable(parameter_a.name),
|
|
"Transformation" => ExpressionContent::InvokeInstanceMethod(
|
|
parameter_a.data_type.clone(),
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Variable(parameter_a.name),
|
|
}),
|
|
involution_result.name,
|
|
vec![],
|
|
),
|
|
_ => unreachable!(),
|
|
},
|
|
},
|
|
)],
|
|
),
|
|
});
|
|
let conversion_result = if let Some(conversion) = conversion {
|
|
result_of_trait!(conversion)
|
|
} else {
|
|
geometric_product_2_result
|
|
};
|
|
AstNode::TraitImplementation {
|
|
result: Parameter {
|
|
name,
|
|
data_type: conversion_result.data_type.clone(),
|
|
},
|
|
parameters: vec![parameter_a.clone(), parameter_b.clone()],
|
|
body: vec![AstNode::ReturnStatement {
|
|
expression: if conversion.is_some() {
|
|
Box::new(Expression {
|
|
size: 1,
|
|
content: ExpressionContent::Conversion(
|
|
geometric_product_2_result.multi_vector_class(),
|
|
conversion_result.multi_vector_class(),
|
|
product,
|
|
),
|
|
})
|
|
} else {
|
|
product
|
|
},
|
|
}],
|
|
}
|
|
}
|
|
}
|