Adds element wise multiplication as Mul and division as Div.

This commit is contained in:
Alexander Meißner 2022-11-26 14:26:39 +01:00
parent e9061a1105
commit 340ef20738
3 changed files with 63 additions and 2 deletions

View file

@ -287,7 +287,7 @@ impl MultiVectorClass {
} }
} }
pub fn sum<'a>( pub fn element_wise<'a>(
name: &'static str, name: &'static str,
parameter_a: &Parameter<'a>, parameter_a: &Parameter<'a>,
parameter_b: &Parameter<'a>, parameter_b: &Parameter<'a>,
@ -355,6 +355,8 @@ impl MultiVectorClass {
content: match name { content: match name {
"Add" => ExpressionContent::Add(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())), "Add" => ExpressionContent::Add(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
"Sub" => ExpressionContent::Subtract(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())), "Sub" => ExpressionContent::Subtract(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
"Mul" => ExpressionContent::Multiply(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
"Div" => ExpressionContent::Divide(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
_ => unreachable!(), _ => unreachable!(),
}, },
})), })),

View file

@ -97,12 +97,21 @@ fn main() {
} }
} }
for name in &["Add", "Sub"] { for name in &["Add", "Sub"] {
let ast_node = MultiVectorClass::sum(*name, &parameter_a, &parameter_b, &registry); let ast_node = MultiVectorClass::element_wise(*name, &parameter_a, &parameter_b, &registry);
emitter.emit(&ast_node).unwrap(); emitter.emit(&ast_node).unwrap();
if ast_node != AstNode::None { if ast_node != AstNode::None {
trait_implementations.insert(name.to_string(), ast_node); trait_implementations.insert(name.to_string(), ast_node);
} }
} }
if class_a == class_b {
for name in &["Mul", "Div"] {
let ast_node = MultiVectorClass::element_wise(*name, &parameter_a, &parameter_b, &registry);
emitter.emit(&ast_node).unwrap();
if ast_node != AstNode::None {
trait_implementations.insert(name.to_string(), ast_node);
}
}
}
for (name, product) in products.iter() { for (name, product) in products.iter() {
let ast_node = MultiVectorClass::product(name, product, &parameter_a, &parameter_b, &registry); let ast_node = MultiVectorClass::product(name, product, &parameter_a, &parameter_b, &registry);
emitter.emit(&ast_node).unwrap(); emitter.emit(&ast_node).unwrap();

View file

@ -412,3 +412,53 @@ impl std::ops::Mul<Simd32x2> for Simd32x2 {
) )
} }
} }
impl std::ops::Div<Simd32x4> for Simd32x4 {
type Output = Simd32x4;
fn div(self, other: Self) -> Self {
match_architecture!(
Self,
{ f128: _mm_div_ps(self.f128, other.f128) },
{ f128: vdivq_f32(self.f128, other.f128) },
{ v128: f32x4_div(self.v128, other.v128) },
{ f32x4: [
self.f32x4[0] / other.f32x4[0],
self.f32x4[1] / other.f32x4[1],
self.f32x4[2] / other.f32x4[2],
self.f32x4[3] / other.f32x4[3],
] },
)
}
}
impl std::ops::Div<Simd32x3> for Simd32x3 {
type Output = Simd32x3;
fn div(self, other: Self) -> Self {
match_architecture!(
Self,
{ v32x4: unsafe { self.v32x4 / other.v32x4 } },
{ f32x3: [
self.f32x3[0] / other.f32x3[0],
self.f32x3[1] / other.f32x3[1],
self.f32x3[2] / other.f32x3[2],
] },
)
}
}
impl std::ops::Div<Simd32x2> for Simd32x2 {
type Output = Simd32x2;
fn div(self, other: Self) -> Self {
match_architecture!(
Self,
{ v32x4: unsafe { self.v32x4 / other.v32x4 } },
{ f32x2: [
self.f32x2[0] / other.f32x2[0],
self.f32x2[1] / other.f32x2[1],
] },
)
}
}