From 340ef207385cae8c67d1fe9bc293285fde447ba7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 26 Nov 2022 14:26:39 +0100 Subject: [PATCH] Adds element wise multiplication as Mul and division as Div. --- codegen/src/compile.rs | 4 +++- codegen/src/main.rs | 11 +++++++++- src/simd.rs | 50 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/codegen/src/compile.rs b/codegen/src/compile.rs index a3e95ec..a3148e2 100644 --- a/codegen/src/compile.rs +++ b/codegen/src/compile.rs @@ -287,7 +287,7 @@ impl MultiVectorClass { } } - pub fn sum<'a>( + pub fn element_wise<'a>( name: &'static str, parameter_a: &Parameter<'a>, parameter_b: &Parameter<'a>, @@ -355,6 +355,8 @@ impl MultiVectorClass { content: match name { "Add" => ExpressionContent::Add(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())), "Sub" => ExpressionContent::Subtract(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())), + "Mul" => ExpressionContent::Multiply(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())), + "Div" => ExpressionContent::Divide(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())), _ => unreachable!(), }, })), diff --git a/codegen/src/main.rs b/codegen/src/main.rs index 4dcaf01..fe0831a 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -97,12 +97,21 @@ fn main() { } } for name in &["Add", "Sub"] { - let ast_node = MultiVectorClass::sum(*name, ¶meter_a, ¶meter_b, ®istry); + let ast_node = MultiVectorClass::element_wise(*name, ¶meter_a, ¶meter_b, ®istry); emitter.emit(&ast_node).unwrap(); if ast_node != AstNode::None { trait_implementations.insert(name.to_string(), ast_node); } } + if class_a == class_b { + for name in &["Mul", "Div"] { + let ast_node = MultiVectorClass::element_wise(*name, ¶meter_a, ¶meter_b, ®istry); + emitter.emit(&ast_node).unwrap(); + if ast_node != AstNode::None { + trait_implementations.insert(name.to_string(), ast_node); + } + } + } for (name, product) in products.iter() { let ast_node = MultiVectorClass::product(name, product, ¶meter_a, ¶meter_b, ®istry); emitter.emit(&ast_node).unwrap(); diff --git a/src/simd.rs b/src/simd.rs index 8be0bfe..69e4132 100755 --- a/src/simd.rs +++ b/src/simd.rs @@ -412,3 +412,53 @@ impl std::ops::Mul for Simd32x2 { ) } } + +impl std::ops::Div for Simd32x4 { + type Output = Simd32x4; + + fn div(self, other: Self) -> Self { + match_architecture!( + Self, + { f128: _mm_div_ps(self.f128, other.f128) }, + { f128: vdivq_f32(self.f128, other.f128) }, + { v128: f32x4_div(self.v128, other.v128) }, + { f32x4: [ + self.f32x4[0] / other.f32x4[0], + self.f32x4[1] / other.f32x4[1], + self.f32x4[2] / other.f32x4[2], + self.f32x4[3] / other.f32x4[3], + ] }, + ) + } +} + +impl std::ops::Div for Simd32x3 { + type Output = Simd32x3; + + fn div(self, other: Self) -> Self { + match_architecture!( + Self, + { v32x4: unsafe { self.v32x4 / other.v32x4 } }, + { f32x3: [ + self.f32x3[0] / other.f32x3[0], + self.f32x3[1] / other.f32x3[1], + self.f32x3[2] / other.f32x3[2], + ] }, + ) + } +} + +impl std::ops::Div for Simd32x2 { + type Output = Simd32x2; + + fn div(self, other: Self) -> Self { + match_architecture!( + Self, + { v32x4: unsafe { self.v32x4 / other.v32x4 } }, + { f32x2: [ + self.f32x2[0] / other.f32x2[0], + self.f32x2[1] / other.f32x2[1], + ] }, + ) + } +}