diff --git a/codegen/src/compile.rs b/codegen/src/compile.rs index d47af00..a3e95ec 100644 --- a/codegen/src/compile.rs +++ b/codegen/src/compile.rs @@ -637,6 +637,58 @@ impl MultiVectorClass { } } + pub fn derive_scale<'a>( + name: &'static str, + geometric_product: &AstNode<'a>, + parameter_a: &Parameter<'a>, + parameter_b: &Parameter<'a>, + ) -> AstNode<'a> { + let geometric_product_result = result_of_trait!(geometric_product); + AstNode::TraitImplementation { + result: Parameter { + name, + data_type: geometric_product_result.data_type.clone(), + }, + parameters: vec![ + parameter_a.clone(), + Parameter { + name: "other", + data_type: DataType::SimdVector(1), + }, + ], + body: vec![AstNode::ReturnStatement { + expression: Box::new(Expression { + size: 1, + content: ExpressionContent::InvokeInstanceMethod( + parameter_a.data_type.clone(), + Box::new(Expression { + size: 1, + content: ExpressionContent::Variable(parameter_a.name), + }), + geometric_product_result.name, + vec![( + DataType::MultiVector(parameter_b.multi_vector_class()), + Expression { + size: 1, + content: ExpressionContent::InvokeClassMethod( + parameter_b.multi_vector_class(), + "Constructor", + vec![( + DataType::SimdVector(1), + Expression { + size: 1, + content: ExpressionContent::Variable(parameter_b.name), + }, + )], + ), + }, + )], + ), + }), + }], + } + } + pub fn derive_magnitude<'a>(name: &'static str, squared_magnitude: &AstNode<'a>, parameter_a: &Parameter<'a>) -> AstNode<'a> { let squared_magnitude_result = result_of_trait!(squared_magnitude); AstNode::TraitImplementation { diff --git a/codegen/src/glsl.rs b/codegen/src/glsl.rs index d66b066..beaeb34 100644 --- a/codegen/src/glsl.rs +++ b/codegen/src/glsl.rs @@ -8,7 +8,8 @@ const COMPONENT: &[&str] = &["x", "y", "z", "w"]; fn emit_data_type(collector: &mut W, data_type: &DataType) -> std::io::Result<()> { match data_type { DataType::Integer => collector.write_all(b"int"), - DataType::SimdVector(size) => collector.write_fmt(format_args!("vec{}", size)), + DataType::SimdVector(size) if *size == 1 => collector.write_all(b"float"), + DataType::SimdVector(size) => collector.write_fmt(format_args!("vec{}", *size)), DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)), } } @@ -223,7 +224,9 @@ pub fn emit_code(collector: &mut W, ast_node: &AstNode, inden camel_to_snake_case(collector, &result.multi_vector_class().class_name)?; } 1 => camel_to_snake_case(collector, ¶meters[0].multi_vector_class().class_name)?, - 2 if result.name == "Powi" => camel_to_snake_case(collector, ¶meters[0].multi_vector_class().class_name)?, + 2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => { + camel_to_snake_case(collector, ¶meters[0].multi_vector_class().class_name)? + } 2 => { camel_to_snake_case(collector, ¶meters[0].multi_vector_class().class_name)?; collector.write_all(b"_")?; diff --git a/codegen/src/main.rs b/codegen/src/main.rs index e714eef..773b2f7 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -133,6 +133,8 @@ fn main() { for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() { if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") { if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement::from_index(0)]] { + let scale = MultiVectorClass::derive_scale("Scale", &geometric_product, ¶meter_a, ¶meter_b); + emitter.emit(&scale).unwrap(); if let Some(magnitude) = single_trait_implementations.get("Magnitude") { let signum = MultiVectorClass::derive_signum("Signum", &geometric_product, &magnitude, ¶meter_a); emitter.emit(&signum).unwrap(); diff --git a/codegen/src/rust.rs b/codegen/src/rust.rs index 33e6807..60564ac 100644 --- a/codegen/src/rust.rs +++ b/codegen/src/rust.rs @@ -6,7 +6,8 @@ use crate::{ fn emit_data_type(collector: &mut W, data_type: &DataType) -> std::io::Result<()> { match data_type { DataType::Integer => collector.write_all(b"isize"), - DataType::SimdVector(size) => collector.write_fmt(format_args!("Simd32x{}", size)), + DataType::SimdVector(size) if *size == 1 => collector.write_all(b"f32"), + DataType::SimdVector(size) => collector.write_fmt(format_args!("Simd32x{}", *size)), DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)), } } @@ -487,7 +488,7 @@ pub fn emit_code(collector: &mut W, ast_node: &AstNode, inden parameters[0].multi_vector_class().class_name, ))?, 1 => collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))?, - 2 if result.name == "Powi" => { + 2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => { collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))? } 2 => collector.write_fmt(format_args!( diff --git a/src/lib.rs b/src/lib.rs index c406cc2..20825ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,7 +80,7 @@ impl Ln for ppga2d::Translator { fn ln(self) -> ppga2d::IdealPoint { let result: ppga2d::IdealPoint = self.into(); - result.geometric_product(ppga2d::Scalar::from([1.0 / self[0]])) + result.scale(1.0 / self[0]) } } @@ -88,9 +88,7 @@ impl Powf for ppga2d::Translator { type Output = Self; fn powf(self, exponent: f32) -> Self { - self.ln() - .geometric_product(ppga2d::Scalar::from([exponent])) - .exp() + self.ln().scale(exponent).exp() } } @@ -129,9 +127,7 @@ impl Powf for ppga2d::Motor { type Output = Self; fn powf(self, exponent: f32) -> Self { - self.ln() - .geometric_product(ppga2d::Scalar::from([exponent])) - .exp() + self.ln().scale(exponent).exp() } } @@ -148,7 +144,7 @@ impl Ln for ppga3d::Translator { fn ln(self) -> ppga3d::IdealPoint { let result: ppga3d::IdealPoint = self.into(); - result.geometric_product(ppga3d::Scalar::from([1.0 / self[0]])) + result.scale(1.0 / self[0]) } } @@ -156,9 +152,7 @@ impl Powf for ppga3d::Translator { type Output = Self; fn powf(self, exponent: f32) -> Self { - self.ln() - .geometric_product(ppga3d::Scalar::from([exponent])) - .exp() + self.ln().scale(exponent).exp() } } @@ -202,9 +196,7 @@ impl Powf for ppga3d::Motor { type Output = Self; fn powf(self, exponent: f32) -> Self { - self.ln() - .geometric_product(ppga3d::Scalar::from([exponent])) - .exp() + self.ln().scale(exponent).exp() } } @@ -308,6 +300,12 @@ pub trait Transformation { fn transformation(self, other: T) -> Self::Output; } +/// Geometric product with a scalar +pub trait Scale { + type Output; + fn scale(self, other: f32) -> Self::Output; +} + /// Square of the magnitude pub trait SquaredMagnitude { type Output;