From b2758277047fb5582dd5678fe70e6ee127888274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 17 Apr 2021 13:43:52 +0200 Subject: [PATCH] Implements debug, index and assign traits in Rust target. --- codegen/src/rust.rs | 36 ++++++++++++++--- src/lib.rs | 22 +++++------ src/simd.rs | 95 +++++++++++++++++++++++++++++++-------------- 3 files changed, 107 insertions(+), 46 deletions(-) diff --git a/codegen/src/rust.rs b/codegen/src/rust.rs index 6ee6b74..d23b60d 100644 --- a/codegen/src/rust.rs +++ b/codegen/src/rust.rs @@ -1,5 +1,5 @@ use crate::{ - ast::{AstNode, DataType, Expression, ExpressionContent}, + ast::{AstNode, DataType, Expression, ExpressionContent, Parameter}, emit::{camel_to_snake_case, emit_indentation}, }; @@ -72,7 +72,7 @@ fn emit_expression(collector: &mut W, expression: &Expression if expression.size == 1 { emit_expression(collector, &inner_expression)?; if inner_expression.size > 1 { - collector.write_fmt(format_args!(".get_f({})", indices[0]))?; + collector.write_fmt(format_args!("[{}]", indices[0]))?; } } else { collector.write_all(b"swizzle!(")?; @@ -101,7 +101,7 @@ fn emit_expression(collector: &mut W, expression: &Expression emit_expression(collector, &inner_expression)?; collector.write_fmt(format_args!(".g{}", array_index))?; if inner_expression.size > 1 { - collector.write_fmt(format_args!(".get_f({})", *component_index))?; + collector.write_fmt(format_args!("[{}]", *component_index))?; } } if indices.len() > 1 { @@ -165,15 +165,35 @@ fn emit_expression(collector: &mut W, expression: &Expression Ok(()) } +fn emit_assign_trait(collector: &mut W, result: &Parameter, parameters: &[Parameter]) -> std::io::Result<()> { + if result.multi_vector_class() != parameters[0].multi_vector_class() { + return Ok(()); + } + collector.write_fmt(format_args!( + "impl {}Assign<{}> for {} {{\n fn ", + result.name, + parameters[1].multi_vector_class().class_name, + parameters[0].multi_vector_class().class_name + ))?; + camel_to_snake_case(collector, result.name)?; + collector.write_fmt(format_args!( + "_assign(&mut self, other: {}) {{\n *self = (*self).", + parameters[1].multi_vector_class().class_name + ))?; + camel_to_snake_case(collector, result.name)?; + collector.write_all(b"(other);\n }\n}\n\n") +} + pub fn emit_code(collector: &mut W, ast_node: &AstNode, indentation: usize) -> std::io::Result<()> { match &ast_node { AstNode::None => {} AstNode::Preamble => { collector.write_all(b"#![allow(clippy::assign_op_pattern)]\n")?; - collector.write_all(b"use crate::{*, simd::*};\nuse std::ops::{Add, Neg, Sub, Mul, Div};\n\n")?; + collector + .write_all(b"use crate::{simd::*, *};\nuse std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};\n\n")?; } AstNode::ClassDefinition { class } => { - collector.write_fmt(format_args!("#[derive(Clone, Copy)]\npub struct {} {{\n", class.class_name))?; + collector.write_fmt(format_args!("#[derive(Clone, Copy, Debug)]\npub struct {} {{\n", class.class_name))?; for (i, group) in class.grouped_basis.iter().enumerate() { emit_indentation(collector, indentation + 1)?; collector.write_all(b"/// ")?; @@ -283,6 +303,12 @@ pub fn emit_code(collector: &mut W, ast_node: &AstNode, inden } emit_indentation(collector, indentation + 1)?; collector.write_all(b"}\n}\n\n")?; + match result.name { + "Add" | "Sub" | "Mul" | "Div" => { + emit_assign_trait(collector, result, ¶meters)?; + } + _ => {} + } } } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index e5e613c..4ddc13f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,10 @@ #![cfg_attr(all(target_arch = "wasm32", target_feature = "simd128"), feature(wasm_simd))] #![cfg_attr(all(any(target_arch = "arm", target_arch = "aarch64"), target_feature = "neon"), feature(stdsimd))] -pub mod simd; pub mod complex; pub mod ppga2d; pub mod ppga3d; +pub mod simd; impl complex::Scalar { pub const fn new(real: f32) -> Self { @@ -34,11 +34,11 @@ impl complex::MultiVector { } pub fn real(self) -> f32 { - self.g0.get_f(0) + self.g0[0] } pub fn imaginary(self) -> f32 { - self.g0.get_f(1) + self.g0[1] } pub fn from_polar(magnitude: f32, angle: f32) -> Self { @@ -71,7 +71,7 @@ pub trait Dual { } /// Negates elements with `grade % 2 == 1` -/// +/// /// Also called main involution pub trait Automorph { type Output; @@ -99,7 +99,7 @@ pub trait GeometricProduct { } /// Dual of the geometric product grade filtered by `t == r + s` -/// +/// /// Also called join pub trait RegressiveProduct { type Output; @@ -107,7 +107,7 @@ pub trait RegressiveProduct { } /// Geometric product grade filtered by `t == r + s` -/// +/// /// Also called meet or exterior product pub trait OuterProduct { type Output; @@ -115,7 +115,7 @@ pub trait OuterProduct { } /// Geometric product grade filtered by `t == (r - s).abs()` -/// +/// /// Also called fat dot product pub trait InnerProduct { type Output; @@ -141,7 +141,7 @@ pub trait ScalarProduct { } /// `self * other * self` -/// +/// /// Basically a sandwich product without an involution pub trait Reflection { type Output; @@ -149,7 +149,7 @@ pub trait Reflection { } /// `self * other * self.transpose()` -/// +/// /// Also called sandwich product pub trait Transformation { type Output; @@ -163,7 +163,7 @@ pub trait SquaredMagnitude { } /// Length as scalar -/// +/// /// Also called amplitude, absolute value or norm pub trait Magnitude { type Output; @@ -171,7 +171,7 @@ pub trait Magnitude { } /// Direction without magnitude (set to scalar `1.0`) -/// +/// /// Also called sign or normalize pub trait Signum { type Output; diff --git a/src/simd.rs b/src/simd.rs index b27d445..9724d64 100755 --- a/src/simd.rs +++ b/src/simd.rs @@ -1,3 +1,5 @@ +use std::ops::{Index, IndexMut}; + #[cfg(target_arch = "aarch64")] pub use std::arch::aarch64::*; #[cfg(target_arch = "arm")] @@ -122,33 +124,63 @@ macro_rules! swizzle { }; } -impl Simd32x4 { - pub fn get_f(&self, index: usize) -> f32 { - unsafe { self.f32x4[index] } - } +impl Index for Simd32x4 { + type Output = f32; - pub fn set_f(&mut self, index: usize, value: f32) { - unsafe { self.f32x4[index] = value; } + fn index(&self, index: usize) -> &Self::Output { + unsafe { &self.f32x4[index] } } } -impl Simd32x3 { - pub fn get_f(&self, index: usize) -> f32 { - unsafe { self.f32x3[index] } - } +impl Index for Simd32x3 { + type Output = f32; - pub fn set_f(&mut self, index: usize, value: f32) { - unsafe { self.f32x3[index] = value; } + fn index(&self, index: usize) -> &Self::Output { + unsafe { &self.f32x3[index] } } } -impl Simd32x2 { - pub fn get_f(&self, index: usize) -> f32 { - unsafe { self.f32x2[index] } - } +impl Index for Simd32x2 { + type Output = f32; - pub fn set_f(&mut self, index: usize, value: f32) { - unsafe { self.f32x2[index] = value; } + fn index(&self, index: usize) -> &Self::Output { + unsafe { &self.f32x2[index] } + } +} + +impl IndexMut for Simd32x4 { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + unsafe { &mut self.f32x4[index] } + } +} + +impl IndexMut for Simd32x3 { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + unsafe { &mut self.f32x3[index] } + } +} + +impl IndexMut for Simd32x2 { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + unsafe { &mut self.f32x2[index] } + } +} + +impl std::convert::From for [f32; 4] { + fn from(simd: Simd32x4) -> Self { + unsafe { simd.f32x4 } + } +} + +impl std::convert::From for [f32; 3] { + fn from(simd: Simd32x3) -> Self { + unsafe { simd.f32x3 } + } +} + +impl std::convert::From for [f32; 2] { + fn from(simd: Simd32x2) -> Self { + unsafe { simd.f32x2 } } } @@ -196,30 +228,33 @@ impl std::convert::From for Simd32x2 { impl std::fmt::Debug for Simd32x4 { fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.debug_tuple("Vec4") - .field(&self.get_f(0)) - .field(&self.get_f(1)) - .field(&self.get_f(2)) - .field(&self.get_f(3)) + formatter + .debug_tuple("Vec4") + .field(&self[0]) + .field(&self[1]) + .field(&self[2]) + .field(&self[3]) .finish() } } impl std::fmt::Debug for Simd32x3 { fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.debug_tuple("Vec3") - .field(&self.get_f(0)) - .field(&self.get_f(1)) - .field(&self.get_f(2)) + formatter + .debug_tuple("Vec3") + .field(&self[0]) + .field(&self[1]) + .field(&self[2]) .finish() } } impl std::fmt::Debug for Simd32x2 { fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.debug_tuple("Vec2") - .field(&self.get_f(0)) - .field(&self.get_f(1)) + formatter + .debug_tuple("Vec2") + .field(&self[0]) + .field(&self[1]) .finish() } }