Code cleanup.

This commit is contained in:
Alexander Meißner 2022-11-26 15:54:47 +01:00
parent 09948e3c3e
commit e9061a1105
5 changed files with 83 additions and 102 deletions

View file

@ -70,7 +70,7 @@ impl BasisElement {
scalar: self.scalar, scalar: self.scalar,
index: algebra.basis_size() as BasisElementIndex - 1 - self.index, index: algebra.basis_size() as BasisElementIndex - 1 - self.index,
}; };
result.scalar *= BasisElement::product(&self, &result, &algebra).scalar; result.scalar *= BasisElement::product(self, &result, algebra).scalar;
result result
} }
@ -186,7 +186,7 @@ impl Involution {
} }
} }
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq, Eq)]
pub struct ProductTerm { pub struct ProductTerm {
pub product: BasisElement, pub product: BasisElement,
pub factor_a: BasisElement, pub factor_a: BasisElement,
@ -199,18 +199,17 @@ pub struct Product {
} }
impl Product { impl Product {
pub fn product(a: &[BasisElement], b: &[BasisElement], algebra: &GeometricAlgebra) -> Self { pub fn new(a: &[BasisElement], b: &[BasisElement], algebra: &GeometricAlgebra) -> Self {
Self { Self {
terms: a terms: a
.iter() .iter()
.map(|a| { .flat_map(|a| {
b.iter().map(move |b| ProductTerm { b.iter().map(move |b| ProductTerm {
product: BasisElement::product(&a, &b, algebra), product: BasisElement::product(a, b, algebra),
factor_a: a.clone(), factor_a: a.clone(),
factor_b: b.clone(), factor_b: b.clone(),
}) })
}) })
.flatten()
.filter(|term| term.product.scalar != 0) .filter(|term| term.product.scalar != 0)
.collect(), .collect(),
} }
@ -246,12 +245,12 @@ impl Product {
pub fn products(algebra: &GeometricAlgebra) -> Vec<(&'static str, Self)> { pub fn products(algebra: &GeometricAlgebra) -> Vec<(&'static str, Self)> {
let basis = algebra.basis().collect::<Vec<_>>(); let basis = algebra.basis().collect::<Vec<_>>();
let product = Self::product(&basis, &basis, algebra); let product = Self::new(&basis, &basis, algebra);
vec![ vec![
("GeometricProduct", product.clone()), ("GeometricProduct", product.clone()),
("RegressiveProduct", product.projected(|r, s, t| t == r + s).dual(algebra)), ("RegressiveProduct", product.projected(|r, s, t| t == r + s).dual(algebra)),
("OuterProduct", product.projected(|r, s, t| t == r + s)), ("OuterProduct", product.projected(|r, s, t| t == r + s)),
("InnerProduct", product.projected(|r, s, t| t == (r as isize - s as isize).abs() as usize)), ("InnerProduct", product.projected(|r, s, t| t == (r as isize - s as isize).unsigned_abs())),
("LeftContraction", product.projected(|r, s, t| t as isize == s as isize - r as isize)), ("LeftContraction", product.projected(|r, s, t| t as isize == s as isize - r as isize)),
("RightContraction", product.projected(|r, s, t| t as isize == r as isize - s as isize)), ("RightContraction", product.projected(|r, s, t| t as isize == r as isize - s as isize)),
("ScalarProduct", product.projected(|_r, _s, t| t == 0)), ("ScalarProduct", product.projected(|_r, _s, t| t == 0)),

View file

@ -35,7 +35,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
} }
camel_to_snake_case(collector, method_name)?; camel_to_snake_case(collector, method_name)?;
collector.write_all(b"(")?; collector.write_all(b"(")?;
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
if !arguments.is_empty() { if !arguments.is_empty() {
collector.write_all(b", ")?; collector.write_all(b", ")?;
} }
@ -56,7 +56,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
if i > 0 { if i > 0 {
collector.write_all(b", ")?; collector.write_all(b", ")?;
} }
emit_expression(collector, &argument)?; emit_expression(collector, argument)?;
} }
collector.write_all(b")")?; collector.write_all(b")")?;
} }
@ -65,23 +65,23 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
collector.write_all(b"_")?; collector.write_all(b"_")?;
camel_to_snake_case(collector, &destination_class.class_name)?; camel_to_snake_case(collector, &destination_class.class_name)?;
collector.write_all(b"_into(")?; collector.write_all(b"_into(")?;
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_all(b")")?; collector.write_all(b")")?;
} }
ExpressionContent::Select(condition_expression, then_expression, else_expression) => { ExpressionContent::Select(condition_expression, then_expression, else_expression) => {
collector.write_all(b"(")?; collector.write_all(b"(")?;
emit_expression(collector, &condition_expression)?; emit_expression(collector, condition_expression)?;
collector.write_all(b") ? ")?; collector.write_all(b") ? ")?;
emit_expression(collector, &then_expression)?; emit_expression(collector, then_expression)?;
collector.write_all(b" : ")?; collector.write_all(b" : ")?;
emit_expression(collector, &else_expression)?; emit_expression(collector, else_expression)?;
} }
ExpressionContent::Access(inner_expression, array_index) => { ExpressionContent::Access(inner_expression, array_index) => {
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_fmt(format_args!(".g{}", array_index))?; collector.write_fmt(format_args!(".g{}", array_index))?;
} }
ExpressionContent::Swizzle(inner_expression, indices) => { ExpressionContent::Swizzle(inner_expression, indices) => {
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_all(b".")?; collector.write_all(b".")?;
for component_index in indices.iter() { for component_index in indices.iter() {
collector.write_all(COMPONENT[*component_index].bytes().collect::<Vec<_>>().as_slice())?; collector.write_all(COMPONENT[*component_index].bytes().collect::<Vec<_>>().as_slice())?;
@ -89,13 +89,14 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
} }
ExpressionContent::Gather(inner_expression, indices) => { ExpressionContent::Gather(inner_expression, indices) => {
if expression.size > 1 { if expression.size > 1 {
collector.write_fmt(format_args!("vec{}(", expression.size))?; emit_data_type(collector, &DataType::SimdVector(expression.size))?;
collector.write_all(b"(")?;
} }
for (i, (array_index, component_index)) in indices.iter().enumerate() { for (i, (array_index, component_index)) in indices.iter().enumerate() {
if i > 0 { if i > 0 {
collector.write_all(b", ")?; collector.write_all(b", ")?;
} }
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_fmt(format_args!(".g{}", array_index))?; collector.write_fmt(format_args!(".g{}", array_index))?;
if inner_expression.size > 1 { if inner_expression.size > 1 {
collector.write_fmt(format_args!(".{}", COMPONENT[*component_index]))?; collector.write_fmt(format_args!(".{}", COMPONENT[*component_index]))?;
@ -111,9 +112,9 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
if expression.size == 1 { if expression.size == 1 {
collector.write_fmt(format_args!("{:.1}", values[0] as f32))? collector.write_fmt(format_args!("{:.1}", values[0] as f32))?
} else { } else {
emit_data_type(collector, &DataType::SimdVector(expression.size))?;
collector.write_fmt(format_args!( collector.write_fmt(format_args!(
"vec{}({})", "({})",
expression.size,
values.iter().map(|value| format!("{:.1}", *value as f32)).collect::<Vec<_>>().join(", ") values.iter().map(|value| format!("{:.1}", *value as f32)).collect::<Vec<_>>().join(", ")
))? ))?
} }
@ -122,7 +123,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
}, },
ExpressionContent::SquareRoot(inner_expression) => { ExpressionContent::SquareRoot(inner_expression) => {
collector.write_all(b"sqrt(")?; collector.write_all(b"sqrt(")?;
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_all(b")")?; collector.write_all(b")")?;
} }
ExpressionContent::Add(lhs, rhs) ExpressionContent::Add(lhs, rhs)
@ -136,7 +137,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
if let ExpressionContent::LogicAnd(_, _) = expression.content { if let ExpressionContent::LogicAnd(_, _) = expression.content {
collector.write_all(b"(")?; collector.write_all(b"(")?;
} }
emit_expression(collector, &lhs)?; emit_expression(collector, lhs)?;
collector.write_all(match expression.content { collector.write_all(match expression.content {
ExpressionContent::Add(_, _) => b" + ", ExpressionContent::Add(_, _) => b" + ",
ExpressionContent::Subtract(_, _) => b" - ", ExpressionContent::Subtract(_, _) => b" - ",
@ -148,7 +149,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
ExpressionContent::BitShiftRight(_, _) => b" >> ", ExpressionContent::BitShiftRight(_, _) => b" >> ",
_ => unreachable!(), _ => unreachable!(),
})?; })?;
emit_expression(collector, &rhs)?; emit_expression(collector, rhs)?;
if let ExpressionContent::LogicAnd(_, _) = expression.content { if let ExpressionContent::LogicAnd(_, _) = expression.content {
collector.write_all(b")")?; collector.write_all(b")")?;
} }
@ -174,11 +175,7 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
} }
collector.write_all(b"\n")?; collector.write_all(b"\n")?;
emit_indentation(collector, indentation + 1)?; emit_indentation(collector, indentation + 1)?;
if group.len() == 1 { emit_data_type(collector, &DataType::SimdVector(group.len()))?;
collector.write_all(b"float")?;
} else {
collector.write_fmt(format_args!("vec{}", group.len()))?;
}
collector.write_fmt(format_args!(" g{};\n", i))?; collector.write_fmt(format_args!(" g{};\n", i))?;
} }
emit_indentation(collector, indentation)?; emit_indentation(collector, indentation)?;

View file

@ -33,7 +33,7 @@ fn main() {
let basis = algebra.sorted_basis(); let basis = algebra.sorted_basis();
for b in basis.iter() { for b in basis.iter() {
for a in basis.iter() { for a in basis.iter() {
print!("{:1$} ", BasisElement::product(&a, &b, &algebra), generator_squares.len() + 2); print!("{:1$} ", BasisElement::product(a, b, &algebra), generator_squares.len() + 2);
} }
println!(); println!();
} }
@ -75,7 +75,7 @@ fn main() {
} }
} }
for (name, involution) in involutions.iter() { for (name, involution) in involutions.iter() {
let ast_node = MultiVectorClass::involution(name, &involution, &parameter_a, &registry, false); let ast_node = MultiVectorClass::involution(name, involution, &parameter_a, &registry, false);
emitter.emit(&ast_node).unwrap(); emitter.emit(&ast_node).unwrap();
if ast_node != AstNode::None { if ast_node != AstNode::None {
single_trait_implementations.insert(name.to_string(), ast_node); single_trait_implementations.insert(name.to_string(), ast_node);
@ -90,7 +90,7 @@ fn main() {
}; };
if class_a != class_b { if class_a != class_b {
let name = "Into"; let name = "Into";
let ast_node = MultiVectorClass::involution(name, &Involution::projection(&class_b), &parameter_a, &registry, true); let ast_node = MultiVectorClass::involution(name, &Involution::projection(class_b), &parameter_a, &registry, true);
emitter.emit(&ast_node).unwrap(); emitter.emit(&ast_node).unwrap();
if ast_node != AstNode::None { if ast_node != AstNode::None {
trait_implementations.insert(name.to_string(), ast_node); trait_implementations.insert(name.to_string(), ast_node);
@ -104,7 +104,7 @@ fn main() {
} }
} }
for (name, product) in products.iter() { for (name, product) in products.iter() {
let ast_node = MultiVectorClass::product(name, &product, &parameter_a, &parameter_b, &registry); let ast_node = MultiVectorClass::product(name, product, &parameter_a, &parameter_b, &registry);
emitter.emit(&ast_node).unwrap(); emitter.emit(&ast_node).unwrap();
if ast_node != AstNode::None { if ast_node != AstNode::None {
trait_implementations.insert(name.to_string(), ast_node); trait_implementations.insert(name.to_string(), ast_node);
@ -120,7 +120,7 @@ fn main() {
if let Some(reversal) = single_trait_implementations.get("Reversal") { if let Some(reversal) = single_trait_implementations.get("Reversal") {
if parameter_a.multi_vector_class() == parameter_b.multi_vector_class() { if parameter_a.multi_vector_class() == parameter_b.multi_vector_class() {
let squared_magnitude = let squared_magnitude =
MultiVectorClass::derive_squared_magnitude("SquaredMagnitude", &scalar_product, &reversal, &parameter_a); MultiVectorClass::derive_squared_magnitude("SquaredMagnitude", scalar_product, reversal, &parameter_a);
emitter.emit(&squared_magnitude).unwrap(); emitter.emit(&squared_magnitude).unwrap();
let magnitude = MultiVectorClass::derive_magnitude("Magnitude", &squared_magnitude, &parameter_a); let magnitude = MultiVectorClass::derive_magnitude("Magnitude", &squared_magnitude, &parameter_a);
emitter.emit(&magnitude).unwrap(); emitter.emit(&magnitude).unwrap();
@ -133,17 +133,16 @@ fn main() {
for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() { for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() {
if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") { if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") {
if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement::from_index(0)]] { if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement::from_index(0)]] {
let scale = MultiVectorClass::derive_scale("Scale", &geometric_product, &parameter_a, &parameter_b); let scale = MultiVectorClass::derive_scale("Scale", geometric_product, &parameter_a, parameter_b);
emitter.emit(&scale).unwrap(); emitter.emit(&scale).unwrap();
if let Some(magnitude) = single_trait_implementations.get("Magnitude") { if let Some(magnitude) = single_trait_implementations.get("Magnitude") {
let signum = MultiVectorClass::derive_signum("Signum", &geometric_product, &magnitude, &parameter_a); let signum = MultiVectorClass::derive_signum("Signum", geometric_product, magnitude, &parameter_a);
emitter.emit(&signum).unwrap(); emitter.emit(&signum).unwrap();
single_trait_implementations.insert(result_of_trait!(signum).name.to_string(), signum); single_trait_implementations.insert(result_of_trait!(signum).name.to_string(), signum);
} }
if let Some(squared_magnitude) = single_trait_implementations.get("SquaredMagnitude") { if let Some(squared_magnitude) = single_trait_implementations.get("SquaredMagnitude") {
if let Some(reversal) = single_trait_implementations.get("Reversal") { if let Some(reversal) = single_trait_implementations.get("Reversal") {
let inverse = let inverse = MultiVectorClass::derive_inverse("Inverse", geometric_product, squared_magnitude, reversal, &parameter_a);
MultiVectorClass::derive_inverse("Inverse", &geometric_product, &squared_magnitude, &reversal, &parameter_a);
emitter.emit(&inverse).unwrap(); emitter.emit(&inverse).unwrap();
single_trait_implementations.insert(result_of_trait!(inverse).name.to_string(), inverse); single_trait_implementations.insert(result_of_trait!(inverse).name.to_string(), inverse);
} }
@ -167,10 +166,10 @@ fn main() {
if let Some(inverse) = single_trait_implementations.get("Inverse") { if let Some(inverse) = single_trait_implementations.get("Inverse") {
let power_of_integer = MultiVectorClass::derive_power_of_integer( let power_of_integer = MultiVectorClass::derive_power_of_integer(
"Powi", "Powi",
&geometric_product, geometric_product,
&constant_one, constant_one,
&inverse, inverse,
&parameter_a, parameter_a,
&Parameter { &Parameter {
name: "exponent", name: "exponent",
data_type: DataType::Integer, data_type: DataType::Integer,
@ -182,8 +181,7 @@ fn main() {
} }
if let Some(b_trait_implementations) = trait_implementations.get(&parameter_b.multi_vector_class().class_name) { if let Some(b_trait_implementations) = trait_implementations.get(&parameter_b.multi_vector_class().class_name) {
if let Some(inverse) = b_trait_implementations.1.get("Inverse") { if let Some(inverse) = b_trait_implementations.1.get("Inverse") {
let division = let division = MultiVectorClass::derive_division("GeometricQuotient", geometric_product, inverse, parameter_a, parameter_b);
MultiVectorClass::derive_division("GeometricQuotient", &geometric_product, &inverse, &parameter_a, &parameter_b);
emitter.emit(&division).unwrap(); emitter.emit(&division).unwrap();
} }
} }
@ -200,12 +198,12 @@ fn main() {
{ {
let transformation = MultiVectorClass::derive_sandwich_product( let transformation = MultiVectorClass::derive_sandwich_product(
"Transformation", "Transformation",
&geometric_product, geometric_product,
&geometric_product_2, geometric_product_2,
&reversal, reversal,
c_pair_trait_implementations.1.get("Into"), c_pair_trait_implementations.1.get("Into"),
&parameter_a, parameter_a,
&parameter_b, parameter_b,
); );
emitter.emit(&transformation).unwrap(); emitter.emit(&transformation).unwrap();
} }

View file

@ -21,7 +21,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
ExpressionContent::InvokeClassMethod(_, method_name, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, method_name, arguments) => { ExpressionContent::InvokeClassMethod(_, method_name, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, method_name, arguments) => {
match &expression.content { match &expression.content {
ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, _, _) => { ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, _, _) => {
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_all(b".")?; collector.write_all(b".")?;
} }
ExpressionContent::InvokeClassMethod(class, _, _) => { ExpressionContent::InvokeClassMethod(class, _, _) => {
@ -44,7 +44,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
if *method_name == "Constructor" { if *method_name == "Constructor" {
collector.write_fmt(format_args!("g{}: ", i))?; collector.write_fmt(format_args!("g{}: ", i))?;
} }
emit_expression(collector, &argument)?; emit_expression(collector, argument)?;
} }
if *method_name == "Constructor" { if *method_name == "Constructor" {
collector.write_all(b" } }")?; collector.write_all(b" } }")?;
@ -53,31 +53,31 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
} }
} }
ExpressionContent::Conversion(_source_class, _destination_class, inner_expression) => { ExpressionContent::Conversion(_source_class, _destination_class, inner_expression) => {
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_all(b".into()")?; collector.write_all(b".into()")?;
} }
ExpressionContent::Select(condition_expression, then_expression, else_expression) => { ExpressionContent::Select(condition_expression, then_expression, else_expression) => {
collector.write_all(b"if ")?; collector.write_all(b"if ")?;
emit_expression(collector, &condition_expression)?; emit_expression(collector, condition_expression)?;
collector.write_all(b" { ")?; collector.write_all(b" { ")?;
emit_expression(collector, &then_expression)?; emit_expression(collector, then_expression)?;
collector.write_all(b" } else { ")?; collector.write_all(b" } else { ")?;
emit_expression(collector, &else_expression)?; emit_expression(collector, else_expression)?;
collector.write_all(b" }")?; collector.write_all(b" }")?;
} }
ExpressionContent::Access(inner_expression, array_index) => { ExpressionContent::Access(inner_expression, array_index) => {
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_fmt(format_args!(".group{}()", array_index))?; collector.write_fmt(format_args!(".group{}()", array_index))?;
} }
ExpressionContent::Swizzle(inner_expression, indices) => { ExpressionContent::Swizzle(inner_expression, indices) => {
if expression.size == 1 { if expression.size == 1 {
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
if inner_expression.size > 1 { if inner_expression.size > 1 {
collector.write_fmt(format_args!("[{}]", indices[0]))?; collector.write_fmt(format_args!("[{}]", indices[0]))?;
} }
} else { } else {
collector.write_all(b"swizzle!(")?; collector.write_all(b"swizzle!(")?;
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_all(b", ")?; collector.write_all(b", ")?;
for (i, component_index) in indices.iter().enumerate() { for (i, component_index) in indices.iter().enumerate() {
if i > 0 { if i > 0 {
@ -90,7 +90,8 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
} }
ExpressionContent::Gather(inner_expression, indices) => { ExpressionContent::Gather(inner_expression, indices) => {
if expression.size > 1 { if expression.size > 1 {
collector.write_fmt(format_args!("Simd32x{}::from(", expression.size))?; emit_data_type(collector, &DataType::SimdVector(expression.size))?;
collector.write_all(b"::from(")?;
} }
if indices.len() > 1 { if indices.len() > 1 {
collector.write_all(b"[")?; collector.write_all(b"[")?;
@ -99,7 +100,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
if i > 0 { if i > 0 {
collector.write_all(b", ")?; collector.write_all(b", ")?;
} }
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_fmt(format_args!(".group{}()", array_index))?; collector.write_fmt(format_args!(".group{}()", array_index))?;
if inner_expression.size > 1 { if inner_expression.size > 1 {
collector.write_fmt(format_args!("[{}]", *component_index))?; collector.write_fmt(format_args!("[{}]", *component_index))?;
@ -118,7 +119,8 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
if expression.size == 1 { if expression.size == 1 {
collector.write_fmt(format_args!("{:.1}", values[0] as f32))?; collector.write_fmt(format_args!("{:.1}", values[0] as f32))?;
} else { } else {
collector.write_fmt(format_args!("Simd32x{}::from(", expression.size))?; emit_data_type(collector, &DataType::SimdVector(expression.size))?;
collector.write_all(b"::from(")?;
if values.len() > 1 { if values.len() > 1 {
collector.write_all(b"[")?; collector.write_all(b"[")?;
} }
@ -137,7 +139,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
_ => unreachable!(), _ => unreachable!(),
}, },
ExpressionContent::SquareRoot(inner_expression) => { ExpressionContent::SquareRoot(inner_expression) => {
emit_expression(collector, &inner_expression)?; emit_expression(collector, inner_expression)?;
collector.write_all(b".sqrt()")?; collector.write_all(b".sqrt()")?;
} }
ExpressionContent::Add(lhs, rhs) ExpressionContent::Add(lhs, rhs)
@ -148,7 +150,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
| ExpressionContent::Equal(lhs, rhs) | ExpressionContent::Equal(lhs, rhs)
| ExpressionContent::LogicAnd(lhs, rhs) | ExpressionContent::LogicAnd(lhs, rhs)
| ExpressionContent::BitShiftRight(lhs, rhs) => { | ExpressionContent::BitShiftRight(lhs, rhs) => {
emit_expression(collector, &lhs)?; emit_expression(collector, lhs)?;
collector.write_all(match expression.content { collector.write_all(match expression.content {
ExpressionContent::Add(_, _) => b" + ", ExpressionContent::Add(_, _) => b" + ",
ExpressionContent::Subtract(_, _) => b" - ", ExpressionContent::Subtract(_, _) => b" - ",
@ -160,7 +162,7 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
ExpressionContent::BitShiftRight(_, _) => b" >> ", ExpressionContent::BitShiftRight(_, _) => b" >> ",
_ => unreachable!(), _ => unreachable!(),
})?; })?;
emit_expression(collector, &rhs)?; emit_expression(collector, rhs)?;
} }
} }
Ok(()) Ok(())
@ -210,14 +212,9 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
collector.write_all(b"\n")?; collector.write_all(b"\n")?;
emit_indentation(collector, indentation + 1)?; emit_indentation(collector, indentation + 1)?;
collector.write_fmt(format_args!("g{}: ", j))?; collector.write_fmt(format_args!("g{}: ", j))?;
simd_widths.push(if group.len() == 1 { emit_data_type(collector, &DataType::SimdVector(group.len()))?;
collector.write_all(b"f32")?;
1
} else {
collector.write_fmt(format_args!("Simd32x{}", group.len()))?;
4
});
collector.write_all(b",\n")?; collector.write_all(b",\n")?;
simd_widths.push(if group.len() == 1 { 1 } else { 4 });
} }
collector.write_all(b"}\n\n")?; collector.write_all(b"}\n\n")?;
emit_indentation(collector, indentation)?; emit_indentation(collector, indentation)?;
@ -245,6 +242,8 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
emit_indentation(collector, indentation)?; emit_indentation(collector, indentation)?;
collector.write_fmt(format_args!("impl {} {{\n", class.class_name))?; collector.write_fmt(format_args!("impl {} {{\n", class.class_name))?;
emit_indentation(collector, indentation + 1)?; emit_indentation(collector, indentation + 1)?;
collector.write_all(b"#[allow(clippy::too_many_arguments)]\n")?;
emit_indentation(collector, indentation + 1)?;
collector.write_all(b"pub const fn new(")?; collector.write_all(b"pub const fn new(")?;
for i in 0..element_count { for i in 0..element_count {
if i > 0 { if i > 0 {
@ -278,11 +277,7 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
collector.write_all(b", ")?; collector.write_all(b", ")?;
} }
collector.write_fmt(format_args!("g{}: ", j))?; collector.write_fmt(format_args!("g{}: ", j))?;
if group.len() == 1 { emit_data_type(collector, &DataType::SimdVector(group.len()))?;
collector.write_all(b"f32")?;
} else {
collector.write_fmt(format_args!("Simd32x{}", group.len()))?;
}
} }
collector.write_all(b") -> Self {\n")?; collector.write_all(b") -> Self {\n")?;
emit_indentation(collector, indentation + 2)?; emit_indentation(collector, indentation + 2)?;
@ -301,11 +296,7 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
collector.write_all(b"#[inline(always)]\n")?; collector.write_all(b"#[inline(always)]\n")?;
emit_indentation(collector, indentation + 1)?; emit_indentation(collector, indentation + 1)?;
collector.write_fmt(format_args!("pub fn group{}(&self) -> ", j))?; collector.write_fmt(format_args!("pub fn group{}(&self) -> ", j))?;
if group.len() == 1 { emit_data_type(collector, &DataType::SimdVector(group.len()))?;
collector.write_all(b"f32")?;
} else {
collector.write_fmt(format_args!("Simd32x{}", group.len()))?;
}
collector.write_all(b" {\n")?; collector.write_all(b" {\n")?;
emit_indentation(collector, indentation + 2)?; emit_indentation(collector, indentation + 2)?;
collector.write_fmt(format_args!("unsafe {{ self.groups.g{} }}\n", j))?; collector.write_fmt(format_args!("unsafe {{ self.groups.g{} }}\n", j))?;
@ -315,11 +306,7 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
collector.write_all(b"#[inline(always)]\n")?; collector.write_all(b"#[inline(always)]\n")?;
emit_indentation(collector, indentation + 1)?; emit_indentation(collector, indentation + 1)?;
collector.write_fmt(format_args!("pub fn group{}_mut(&mut self) -> &mut ", j))?; collector.write_fmt(format_args!("pub fn group{}_mut(&mut self) -> &mut ", j))?;
if group.len() == 1 { emit_data_type(collector, &DataType::SimdVector(group.len()))?;
collector.write_all(b"f32")?;
} else {
collector.write_fmt(format_args!("Simd32x{}", group.len()))?;
}
collector.write_all(b" {\n")?; collector.write_all(b" {\n")?;
emit_indentation(collector, indentation + 2)?; emit_indentation(collector, indentation + 2)?;
collector.write_fmt(format_args!("unsafe {{ &mut self.groups.g{} }}\n", j))?; collector.write_fmt(format_args!("unsafe {{ &mut self.groups.g{} }}\n", j))?;
@ -385,11 +372,11 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
collector.write_fmt(format_args!("fn from(vector: {}) -> Self {{\n", class.class_name))?; collector.write_fmt(format_args!("fn from(vector: {}) -> Self {{\n", class.class_name))?;
emit_indentation(collector, indentation + 2)?; emit_indentation(collector, indentation + 2)?;
collector.write_all(b"unsafe { [")?; collector.write_all(b"unsafe { [")?;
for i in 0..element_count { for (i, remapped) in index_remap.iter().enumerate() {
if i > 0 { if i > 0 {
collector.write_all(b", ")?; collector.write_all(b", ")?;
} }
collector.write_fmt(format_args!("vector.elements[{}]", index_remap[i]))?; collector.write_fmt(format_args!("vector.elements[{}]", remapped))?;
} }
collector.write_all(b"] }\n")?; collector.write_all(b"] }\n")?;
emit_indentation(collector, indentation + 1)?; emit_indentation(collector, indentation + 1)?;
@ -537,7 +524,7 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
collector.write_all(b"}\n}\n\n")?; collector.write_all(b"}\n}\n\n")?;
match result.name { match result.name {
"Add" | "Sub" | "Mul" | "Div" => { "Add" | "Sub" | "Mul" | "Div" => {
emit_assign_trait(collector, result, &parameters)?; emit_assign_trait(collector, result, parameters)?;
} }
_ => {} _ => {}
} }

View file

@ -18,9 +18,9 @@ impl epga1d::Scalar {
pub fn sqrt(self) -> epga1d::ComplexNumber { pub fn sqrt(self) -> epga1d::ComplexNumber {
if self[0] < 0.0 { if self[0] < 0.0 {
epga1d::ComplexNumber::from([0.0, (-self[0]).sqrt()]) epga1d::ComplexNumber::new(0.0, (-self[0]).sqrt())
} else { } else {
epga1d::ComplexNumber::from([self[0].sqrt(), 0.0]) epga1d::ComplexNumber::new(self[0].sqrt(), 0.0)
} }
} }
} }
@ -35,7 +35,7 @@ impl epga1d::ComplexNumber {
} }
pub fn from_polar(magnitude: f32, argument: f32) -> Self { pub fn from_polar(magnitude: f32, argument: f32) -> Self {
Self::from([magnitude * argument.cos(), magnitude * argument.sin()]) Self::new(magnitude * argument.cos(), magnitude * argument.sin())
} }
pub fn arg(self) -> f32 { pub fn arg(self) -> f32 {
@ -55,7 +55,7 @@ impl Ln for epga1d::ComplexNumber {
type Output = Self; type Output = Self;
fn ln(self) -> Self { fn ln(self) -> Self {
Self::from([self.magnitude()[0].ln(), self.arg()]) Self::new(self.magnitude()[0].ln(), self.arg())
} }
} }
@ -71,7 +71,7 @@ impl Exp for ppga2d::IdealPoint {
type Output = ppga2d::Translator; type Output = ppga2d::Translator;
fn exp(self) -> ppga2d::Translator { fn exp(self) -> ppga2d::Translator {
ppga2d::Translator::from([1.0, self[0], self[1]]) ppga2d::Translator::new(1.0, self[0], self[1])
} }
} }
@ -98,13 +98,13 @@ impl Exp for ppga2d::Point {
fn exp(self) -> ppga2d::Motor { fn exp(self) -> ppga2d::Motor {
let det = self[0] * self[0]; let det = self[0] * self[0];
if det <= 0.0 { if det <= 0.0 {
return ppga2d::Motor::from([1.0, 0.0, self[1], self[2]]); return ppga2d::Motor::new(1.0, 0.0, self[1], self[2]);
} }
let a = det.sqrt(); let a = det.sqrt();
let c = a.cos(); let c = a.cos();
let s = a.sin() / a; let s = a.sin() / a;
let g0 = simd::Simd32x3::from(s) * self.group0(); let g0 = simd::Simd32x3::from(s) * self.group0();
ppga2d::Motor::from([c, g0[0], g0[1], g0[2]]) ppga2d::Motor::new(c, g0[0], g0[1], g0[2])
} }
} }
@ -114,12 +114,12 @@ impl Ln for ppga2d::Motor {
fn ln(self) -> ppga2d::Point { fn ln(self) -> ppga2d::Point {
let det = 1.0 - self[0] * self[0]; let det = 1.0 - self[0] * self[0];
if det <= 0.0 { if det <= 0.0 {
return ppga2d::Point::from([0.0, self[2], self[3]]); return ppga2d::Point::new(0.0, self[2], self[3]);
} }
let a = 1.0 / det; let a = 1.0 / det;
let b = self[0].acos() * a.sqrt(); let b = self[0].acos() * a.sqrt();
let g0 = simd::Simd32x4::from(b) * self.group0(); let g0 = simd::Simd32x4::from(b) * self.group0();
return ppga2d::Point::from([g0[1], g0[2], g0[3]]); ppga2d::Point::new(g0[1], g0[2], g0[3])
} }
} }
@ -135,7 +135,7 @@ impl Exp for ppga3d::IdealPoint {
type Output = ppga3d::Translator; type Output = ppga3d::Translator;
fn exp(self) -> ppga3d::Translator { fn exp(self) -> ppga3d::Translator {
ppga3d::Translator::from([1.0, self[0], self[1], self[2]]) ppga3d::Translator::new(1.0, self[0], self[1], self[2])
} }
} }
@ -162,7 +162,7 @@ impl Exp for ppga3d::Line {
fn exp(self) -> ppga3d::Motor { fn exp(self) -> ppga3d::Motor {
let det = self[3] * self[3] + self[4] * self[4] + self[5] * self[5]; let det = self[3] * self[3] + self[4] * self[4] + self[5] * self[5];
if det <= 0.0 { if det <= 0.0 {
return ppga3d::Motor::from([1.0, 0.0, 0.0, 0.0, 0.0, self[0], self[1], self[2]]); return ppga3d::Motor::new(1.0, 0.0, 0.0, 0.0, 0.0, self[0], self[1], self[2]);
} }
let a = det.sqrt(); let a = det.sqrt();
let c = a.cos(); let c = a.cos();
@ -171,7 +171,7 @@ impl Exp for ppga3d::Line {
let t = m / det * (c - s); let t = m / det * (c - s);
let g0 = simd::Simd32x3::from(s) * self.group1(); let g0 = simd::Simd32x3::from(s) * self.group1();
let g1 = simd::Simd32x3::from(s) * self.group0() + simd::Simd32x3::from(t) * self.group1(); let g1 = simd::Simd32x3::from(s) * self.group0() + simd::Simd32x3::from(t) * self.group1();
ppga3d::Motor::from([c, g0[0], g0[1], g0[2], s * m, g1[0], g1[1], g1[2]]) ppga3d::Motor::new(c, g0[0], g0[1], g0[2], s * m, g1[0], g1[1], g1[2])
} }
} }
@ -181,14 +181,14 @@ impl Ln for ppga3d::Motor {
fn ln(self) -> ppga3d::Line { fn ln(self) -> ppga3d::Line {
let det = 1.0 - self[0] * self[0]; let det = 1.0 - self[0] * self[0];
if det <= 0.0 { if det <= 0.0 {
return ppga3d::Line::from([self[5], self[6], self[7], 0.0, 0.0, 0.0]); return ppga3d::Line::new(self[5], self[6], self[7], 0.0, 0.0, 0.0);
} }
let a = 1.0 / det; let a = 1.0 / det;
let b = self[0].acos() * a.sqrt(); let b = self[0].acos() * a.sqrt();
let c = a * self[4] * (1.0 - self[0] * b); let c = a * self[4] * (1.0 - self[0] * b);
let g0 = simd::Simd32x4::from(b) * self.group1() + simd::Simd32x4::from(c) * self.group0(); let g0 = simd::Simd32x4::from(b) * self.group1() + simd::Simd32x4::from(c) * self.group0();
let g1 = simd::Simd32x4::from(b) * self.group0(); let g1 = simd::Simd32x4::from(b) * self.group0();
return ppga3d::Line::from([g0[1], g0[2], g0[3], g1[1], g1[2], g1[3]]); ppga3d::Line::new(g0[1], g0[2], g0[3], g1[1], g1[2], g1[3])
} }
} }