Adds Scale trait.

This commit is contained in:
Alexander Meißner 2022-11-26 15:50:07 +01:00
parent cdebda2e29
commit 09948e3c3e
5 changed files with 74 additions and 18 deletions

View file

@ -637,6 +637,58 @@ impl MultiVectorClass {
}
}
pub fn derive_scale<'a>(
name: &'static str,
geometric_product: &AstNode<'a>,
parameter_a: &Parameter<'a>,
parameter_b: &Parameter<'a>,
) -> AstNode<'a> {
let geometric_product_result = result_of_trait!(geometric_product);
AstNode::TraitImplementation {
result: Parameter {
name,
data_type: geometric_product_result.data_type.clone(),
},
parameters: vec![
parameter_a.clone(),
Parameter {
name: "other",
data_type: DataType::SimdVector(1),
},
],
body: vec![AstNode::ReturnStatement {
expression: Box::new(Expression {
size: 1,
content: ExpressionContent::InvokeInstanceMethod(
parameter_a.data_type.clone(),
Box::new(Expression {
size: 1,
content: ExpressionContent::Variable(parameter_a.name),
}),
geometric_product_result.name,
vec![(
DataType::MultiVector(parameter_b.multi_vector_class()),
Expression {
size: 1,
content: ExpressionContent::InvokeClassMethod(
parameter_b.multi_vector_class(),
"Constructor",
vec![(
DataType::SimdVector(1),
Expression {
size: 1,
content: ExpressionContent::Variable(parameter_b.name),
},
)],
),
},
)],
),
}),
}],
}
}
pub fn derive_magnitude<'a>(name: &'static str, squared_magnitude: &AstNode<'a>, parameter_a: &Parameter<'a>) -> AstNode<'a> {
let squared_magnitude_result = result_of_trait!(squared_magnitude);
AstNode::TraitImplementation {

View file

@ -8,7 +8,8 @@ const COMPONENT: &[&str] = &["x", "y", "z", "w"];
fn emit_data_type<W: std::io::Write>(collector: &mut W, data_type: &DataType) -> std::io::Result<()> {
match data_type {
DataType::Integer => collector.write_all(b"int"),
DataType::SimdVector(size) => collector.write_fmt(format_args!("vec{}", size)),
DataType::SimdVector(size) if *size == 1 => collector.write_all(b"float"),
DataType::SimdVector(size) => collector.write_fmt(format_args!("vec{}", *size)),
DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)),
}
}
@ -223,7 +224,9 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
camel_to_snake_case(collector, &result.multi_vector_class().class_name)?;
}
1 => camel_to_snake_case(collector, &parameters[0].multi_vector_class().class_name)?,
2 if result.name == "Powi" => camel_to_snake_case(collector, &parameters[0].multi_vector_class().class_name)?,
2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => {
camel_to_snake_case(collector, &parameters[0].multi_vector_class().class_name)?
}
2 => {
camel_to_snake_case(collector, &parameters[0].multi_vector_class().class_name)?;
collector.write_all(b"_")?;

View file

@ -133,6 +133,8 @@ fn main() {
for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() {
if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") {
if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement::from_index(0)]] {
let scale = MultiVectorClass::derive_scale("Scale", &geometric_product, &parameter_a, &parameter_b);
emitter.emit(&scale).unwrap();
if let Some(magnitude) = single_trait_implementations.get("Magnitude") {
let signum = MultiVectorClass::derive_signum("Signum", &geometric_product, &magnitude, &parameter_a);
emitter.emit(&signum).unwrap();

View file

@ -6,7 +6,8 @@ use crate::{
fn emit_data_type<W: std::io::Write>(collector: &mut W, data_type: &DataType) -> std::io::Result<()> {
match data_type {
DataType::Integer => collector.write_all(b"isize"),
DataType::SimdVector(size) => collector.write_fmt(format_args!("Simd32x{}", size)),
DataType::SimdVector(size) if *size == 1 => collector.write_all(b"f32"),
DataType::SimdVector(size) => collector.write_fmt(format_args!("Simd32x{}", *size)),
DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)),
}
}
@ -487,7 +488,7 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
parameters[0].multi_vector_class().class_name,
))?,
1 => collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))?,
2 if result.name == "Powi" => {
2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => {
collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))?
}
2 => collector.write_fmt(format_args!(

View file

@ -80,7 +80,7 @@ impl Ln for ppga2d::Translator {
fn ln(self) -> ppga2d::IdealPoint {
let result: ppga2d::IdealPoint = self.into();
result.geometric_product(ppga2d::Scalar::from([1.0 / self[0]]))
result.scale(1.0 / self[0])
}
}
@ -88,9 +88,7 @@ impl Powf for ppga2d::Translator {
type Output = Self;
fn powf(self, exponent: f32) -> Self {
self.ln()
.geometric_product(ppga2d::Scalar::from([exponent]))
.exp()
self.ln().scale(exponent).exp()
}
}
@ -129,9 +127,7 @@ impl Powf for ppga2d::Motor {
type Output = Self;
fn powf(self, exponent: f32) -> Self {
self.ln()
.geometric_product(ppga2d::Scalar::from([exponent]))
.exp()
self.ln().scale(exponent).exp()
}
}
@ -148,7 +144,7 @@ impl Ln for ppga3d::Translator {
fn ln(self) -> ppga3d::IdealPoint {
let result: ppga3d::IdealPoint = self.into();
result.geometric_product(ppga3d::Scalar::from([1.0 / self[0]]))
result.scale(1.0 / self[0])
}
}
@ -156,9 +152,7 @@ impl Powf for ppga3d::Translator {
type Output = Self;
fn powf(self, exponent: f32) -> Self {
self.ln()
.geometric_product(ppga3d::Scalar::from([exponent]))
.exp()
self.ln().scale(exponent).exp()
}
}
@ -202,9 +196,7 @@ impl Powf for ppga3d::Motor {
type Output = Self;
fn powf(self, exponent: f32) -> Self {
self.ln()
.geometric_product(ppga3d::Scalar::from([exponent]))
.exp()
self.ln().scale(exponent).exp()
}
}
@ -308,6 +300,12 @@ pub trait Transformation<T> {
fn transformation(self, other: T) -> Self::Output;
}
/// Geometric product with a scalar
pub trait Scale {
type Output;
fn scale(self, other: f32) -> Self::Output;
}
/// Square of the magnitude
pub trait SquaredMagnitude {
type Output;