Adds element wise multiplication as Mul and division as Div.
This commit is contained in:
parent
e9061a1105
commit
340ef20738
3 changed files with 63 additions and 2 deletions
|
|
@ -287,7 +287,7 @@ impl MultiVectorClass {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sum<'a>(
|
pub fn element_wise<'a>(
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
parameter_a: &Parameter<'a>,
|
parameter_a: &Parameter<'a>,
|
||||||
parameter_b: &Parameter<'a>,
|
parameter_b: &Parameter<'a>,
|
||||||
|
|
@ -355,6 +355,8 @@ impl MultiVectorClass {
|
||||||
content: match name {
|
content: match name {
|
||||||
"Add" => ExpressionContent::Add(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
|
"Add" => ExpressionContent::Add(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
|
||||||
"Sub" => ExpressionContent::Subtract(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
|
"Sub" => ExpressionContent::Subtract(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
|
||||||
|
"Mul" => ExpressionContent::Multiply(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
|
||||||
|
"Div" => ExpressionContent::Divide(Box::new(expressions.next().unwrap()), Box::new(expressions.next().unwrap())),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
},
|
},
|
||||||
})),
|
})),
|
||||||
|
|
|
||||||
|
|
@ -97,12 +97,21 @@ fn main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for name in &["Add", "Sub"] {
|
for name in &["Add", "Sub"] {
|
||||||
let ast_node = MultiVectorClass::sum(*name, ¶meter_a, ¶meter_b, ®istry);
|
let ast_node = MultiVectorClass::element_wise(*name, ¶meter_a, ¶meter_b, ®istry);
|
||||||
emitter.emit(&ast_node).unwrap();
|
emitter.emit(&ast_node).unwrap();
|
||||||
if ast_node != AstNode::None {
|
if ast_node != AstNode::None {
|
||||||
trait_implementations.insert(name.to_string(), ast_node);
|
trait_implementations.insert(name.to_string(), ast_node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if class_a == class_b {
|
||||||
|
for name in &["Mul", "Div"] {
|
||||||
|
let ast_node = MultiVectorClass::element_wise(*name, ¶meter_a, ¶meter_b, ®istry);
|
||||||
|
emitter.emit(&ast_node).unwrap();
|
||||||
|
if ast_node != AstNode::None {
|
||||||
|
trait_implementations.insert(name.to_string(), ast_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
for (name, product) in products.iter() {
|
for (name, product) in products.iter() {
|
||||||
let ast_node = MultiVectorClass::product(name, product, ¶meter_a, ¶meter_b, ®istry);
|
let ast_node = MultiVectorClass::product(name, product, ¶meter_a, ¶meter_b, ®istry);
|
||||||
emitter.emit(&ast_node).unwrap();
|
emitter.emit(&ast_node).unwrap();
|
||||||
|
|
|
||||||
50
src/simd.rs
50
src/simd.rs
|
|
@ -412,3 +412,53 @@ impl std::ops::Mul<Simd32x2> for Simd32x2 {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::ops::Div<Simd32x4> for Simd32x4 {
|
||||||
|
type Output = Simd32x4;
|
||||||
|
|
||||||
|
fn div(self, other: Self) -> Self {
|
||||||
|
match_architecture!(
|
||||||
|
Self,
|
||||||
|
{ f128: _mm_div_ps(self.f128, other.f128) },
|
||||||
|
{ f128: vdivq_f32(self.f128, other.f128) },
|
||||||
|
{ v128: f32x4_div(self.v128, other.v128) },
|
||||||
|
{ f32x4: [
|
||||||
|
self.f32x4[0] / other.f32x4[0],
|
||||||
|
self.f32x4[1] / other.f32x4[1],
|
||||||
|
self.f32x4[2] / other.f32x4[2],
|
||||||
|
self.f32x4[3] / other.f32x4[3],
|
||||||
|
] },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Div<Simd32x3> for Simd32x3 {
|
||||||
|
type Output = Simd32x3;
|
||||||
|
|
||||||
|
fn div(self, other: Self) -> Self {
|
||||||
|
match_architecture!(
|
||||||
|
Self,
|
||||||
|
{ v32x4: unsafe { self.v32x4 / other.v32x4 } },
|
||||||
|
{ f32x3: [
|
||||||
|
self.f32x3[0] / other.f32x3[0],
|
||||||
|
self.f32x3[1] / other.f32x3[1],
|
||||||
|
self.f32x3[2] / other.f32x3[2],
|
||||||
|
] },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Div<Simd32x2> for Simd32x2 {
|
||||||
|
type Output = Simd32x2;
|
||||||
|
|
||||||
|
fn div(self, other: Self) -> Self {
|
||||||
|
match_architecture!(
|
||||||
|
Self,
|
||||||
|
{ v32x4: unsafe { self.v32x4 / other.v32x4 } },
|
||||||
|
{ f32x2: [
|
||||||
|
self.f32x2[0] / other.f32x2[0],
|
||||||
|
self.f32x2[1] / other.f32x2[1],
|
||||||
|
] },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue