Blank scalars without the wrapper class.
This commit is contained in:
parent
a7090ff329
commit
d73b6a253c
6 changed files with 309 additions and 115 deletions
|
|
@ -7,6 +7,16 @@ pub enum DataType<'a> {
|
||||||
MultiVector(&'a MultiVectorClass),
|
MultiVector(&'a MultiVectorClass),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl DataType<'_> {
|
||||||
|
pub fn is_scalar(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::SimdVector(1) => true,
|
||||||
|
Self::MultiVector(multi_vector_class) => multi_vector_class.is_scalar(),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||||
pub enum ExpressionContent<'a> {
|
pub enum ExpressionContent<'a> {
|
||||||
None,
|
None,
|
||||||
|
|
@ -42,6 +52,19 @@ pub struct Expression<'a> {
|
||||||
pub content: ExpressionContent<'a>,
|
pub content: ExpressionContent<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Expression<'_> {
|
||||||
|
pub fn is_scalar(&self) -> bool {
|
||||||
|
if self.size > 1 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
match &self.content {
|
||||||
|
ExpressionContent::Variable(data_type, _) => data_type.is_scalar(),
|
||||||
|
ExpressionContent::InvokeInstanceMethod(_, _, _, result_data_type, _) => result_data_type.is_scalar(),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||||
pub struct Parameter<'a> {
|
pub struct Parameter<'a> {
|
||||||
pub name: &'static str,
|
pub name: &'static str,
|
||||||
|
|
|
||||||
|
|
@ -146,6 +146,10 @@ impl MultiVectorClass {
|
||||||
self.grouped_basis.iter().flatten().cloned().collect()
|
self.grouped_basis.iter().flatten().cloned().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_scalar(&self) -> bool {
|
||||||
|
self.flat_basis() == vec![BasisElement { scalar: 1, index: 0 }]
|
||||||
|
}
|
||||||
|
|
||||||
pub fn signature(&self) -> Vec<BasisElementIndex> {
|
pub fn signature(&self) -> Vec<BasisElementIndex> {
|
||||||
let mut signature: Vec<BasisElementIndex> = self.grouped_basis.iter().flatten().map(|element| element.index).collect();
|
let mut signature: Vec<BasisElementIndex> = self.grouped_basis.iter().flatten().map(|element| element.index).collect();
|
||||||
signature.sort_unstable();
|
signature.sort_unstable();
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,8 @@ fn emit_data_type<W: std::io::Write>(collector: &mut W, data_type: &DataType) ->
|
||||||
DataType::Integer => collector.write_all(b"int"),
|
DataType::Integer => collector.write_all(b"int"),
|
||||||
DataType::SimdVector(size) if *size == 1 => collector.write_all(b"float"),
|
DataType::SimdVector(size) if *size == 1 => collector.write_all(b"float"),
|
||||||
DataType::SimdVector(size) => collector.write_fmt(format_args!("vec{}", *size)),
|
DataType::SimdVector(size) => collector.write_fmt(format_args!("vec{}", *size)),
|
||||||
DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)),
|
DataType::MultiVector(class) if class.is_scalar() => collector.write_all(b"float"),
|
||||||
|
DataType::MultiVector(class) => collector.write_all(class.class_name.as_bytes()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -20,6 +21,9 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
|
||||||
ExpressionContent::Variable(_data_type, name) => {
|
ExpressionContent::Variable(_data_type, name) => {
|
||||||
collector.write_all(name.bytes().collect::<Vec<_>>().as_slice())?;
|
collector.write_all(name.bytes().collect::<Vec<_>>().as_slice())?;
|
||||||
}
|
}
|
||||||
|
ExpressionContent::InvokeClassMethod(class, "Constructor", arguments) if class.is_scalar() => {
|
||||||
|
emit_expression(collector, &arguments[0].1)?;
|
||||||
|
}
|
||||||
ExpressionContent::InvokeClassMethod(_, _, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, _, _, arguments) => {
|
ExpressionContent::InvokeClassMethod(_, _, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, _, _, arguments) => {
|
||||||
match &expression.content {
|
match &expression.content {
|
||||||
ExpressionContent::InvokeInstanceMethod(result_class, inner_expression, method_name, _, _) => {
|
ExpressionContent::InvokeInstanceMethod(result_class, inner_expression, method_name, _, _) => {
|
||||||
|
|
@ -78,7 +82,9 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
|
||||||
}
|
}
|
||||||
ExpressionContent::Access(inner_expression, array_index) => {
|
ExpressionContent::Access(inner_expression, array_index) => {
|
||||||
emit_expression(collector, inner_expression)?;
|
emit_expression(collector, inner_expression)?;
|
||||||
collector.write_fmt(format_args!(".g{}", array_index))?;
|
if !inner_expression.is_scalar() {
|
||||||
|
collector.write_fmt(format_args!(".g{}", array_index))?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ExpressionContent::Swizzle(inner_expression, indices) => {
|
ExpressionContent::Swizzle(inner_expression, indices) => {
|
||||||
emit_expression(collector, inner_expression)?;
|
emit_expression(collector, inner_expression)?;
|
||||||
|
|
@ -88,22 +94,28 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExpressionContent::Gather(inner_expression, indices) => {
|
ExpressionContent::Gather(inner_expression, indices) => {
|
||||||
if expression.size > 1 {
|
if expression.size == 1 && inner_expression.is_scalar() {
|
||||||
emit_data_type(collector, &DataType::SimdVector(expression.size))?;
|
|
||||||
collector.write_all(b"(")?;
|
|
||||||
}
|
|
||||||
for (i, (array_index, component_index)) in indices.iter().enumerate() {
|
|
||||||
if i > 0 {
|
|
||||||
collector.write_all(b", ")?;
|
|
||||||
}
|
|
||||||
emit_expression(collector, inner_expression)?;
|
emit_expression(collector, inner_expression)?;
|
||||||
collector.write_fmt(format_args!(".g{}", array_index))?;
|
} else {
|
||||||
if inner_expression.size > 1 {
|
if expression.size > 1 {
|
||||||
collector.write_fmt(format_args!(".{}", COMPONENT[*component_index]))?;
|
emit_data_type(collector, &DataType::SimdVector(expression.size))?;
|
||||||
|
collector.write_all(b"(")?;
|
||||||
|
}
|
||||||
|
for (i, (array_index, component_index)) in indices.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
collector.write_all(b", ")?;
|
||||||
|
}
|
||||||
|
emit_expression(collector, inner_expression)?;
|
||||||
|
if !inner_expression.is_scalar() {
|
||||||
|
collector.write_fmt(format_args!(".g{}", array_index))?;
|
||||||
|
if inner_expression.size > 1 {
|
||||||
|
collector.write_fmt(format_args!(".{}", COMPONENT[*component_index]))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if expression.size > 1 {
|
||||||
|
collector.write_all(b")")?;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if expression.size > 1 {
|
|
||||||
collector.write_all(b")")?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExpressionContent::Constant(data_type, values) => match data_type {
|
ExpressionContent::Constant(data_type, values) => match data_type {
|
||||||
|
|
@ -163,6 +175,9 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
|
||||||
AstNode::None => {}
|
AstNode::None => {}
|
||||||
AstNode::Preamble => {}
|
AstNode::Preamble => {}
|
||||||
AstNode::ClassDefinition { class } => {
|
AstNode::ClassDefinition { class } => {
|
||||||
|
if class.is_scalar() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
collector.write_fmt(format_args!("struct {} {{\n", class.class_name))?;
|
collector.write_fmt(format_args!("struct {} {{\n", class.class_name))?;
|
||||||
for (i, group) in class.grouped_basis.iter().enumerate() {
|
for (i, group) in class.grouped_basis.iter().enumerate() {
|
||||||
emit_indentation(collector, indentation + 1)?;
|
emit_indentation(collector, indentation + 1)?;
|
||||||
|
|
@ -212,7 +227,8 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
|
||||||
collector.write_all(b"}\n")?;
|
collector.write_all(b"}\n")?;
|
||||||
}
|
}
|
||||||
AstNode::TraitImplementation { result, parameters, body } => {
|
AstNode::TraitImplementation { result, parameters, body } => {
|
||||||
collector.write_fmt(format_args!("{} ", result.multi_vector_class().class_name))?;
|
emit_data_type(collector, &result.data_type)?;
|
||||||
|
collector.write_all(b" ")?;
|
||||||
match parameters.len() {
|
match parameters.len() {
|
||||||
0 => camel_to_snake_case(collector, &result.multi_vector_class().class_name)?,
|
0 => camel_to_snake_case(collector, &result.multi_vector_class().class_name)?,
|
||||||
1 if result.name == "Into" => {
|
1 if result.name == "Into" => {
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ fn emit_data_type<W: std::io::Write>(collector: &mut W, data_type: &DataType) ->
|
||||||
DataType::Integer => collector.write_all(b"isize"),
|
DataType::Integer => collector.write_all(b"isize"),
|
||||||
DataType::SimdVector(size) if *size == 1 => collector.write_all(b"f32"),
|
DataType::SimdVector(size) if *size == 1 => collector.write_all(b"f32"),
|
||||||
DataType::SimdVector(size) => collector.write_fmt(format_args!("Simd32x{}", *size)),
|
DataType::SimdVector(size) => collector.write_fmt(format_args!("Simd32x{}", *size)),
|
||||||
|
DataType::MultiVector(class) if class.is_scalar() => collector.write_all(b"f32"),
|
||||||
DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)),
|
DataType::MultiVector(class) => collector.write_fmt(format_args!("{}", class.class_name)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -18,39 +19,45 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
|
||||||
ExpressionContent::Variable(_data_type, name) => {
|
ExpressionContent::Variable(_data_type, name) => {
|
||||||
collector.write_all(name.bytes().collect::<Vec<_>>().as_slice())?;
|
collector.write_all(name.bytes().collect::<Vec<_>>().as_slice())?;
|
||||||
}
|
}
|
||||||
ExpressionContent::InvokeClassMethod(_, method_name, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, method_name, _, arguments) => {
|
ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, method_name, _, arguments) => {
|
||||||
match &expression.content {
|
emit_expression(collector, inner_expression)?;
|
||||||
ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, _, _, _) => {
|
collector.write_all(b".")?;
|
||||||
emit_expression(collector, inner_expression)?;
|
camel_to_snake_case(collector, method_name)?;
|
||||||
collector.write_all(b".")?;
|
collector.write_all(b"(")?;
|
||||||
}
|
|
||||||
ExpressionContent::InvokeClassMethod(class, _, _) => {
|
|
||||||
if *method_name == "Constructor" {
|
|
||||||
collector.write_fmt(format_args!("{} {{ groups: {}Groups {{ ", class.class_name, class.class_name))?;
|
|
||||||
} else {
|
|
||||||
collector.write_fmt(format_args!("{}::", class.class_name))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
if *method_name != "Constructor" {
|
|
||||||
camel_to_snake_case(collector, method_name)?;
|
|
||||||
collector.write_all(b"(")?;
|
|
||||||
}
|
|
||||||
for (i, (_argument_class, argument)) in arguments.iter().enumerate() {
|
for (i, (_argument_class, argument)) in arguments.iter().enumerate() {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
collector.write_all(b", ")?;
|
collector.write_all(b", ")?;
|
||||||
}
|
}
|
||||||
if *method_name == "Constructor" {
|
emit_expression(collector, argument)?;
|
||||||
collector.write_fmt(format_args!("g{}: ", i))?;
|
}
|
||||||
|
collector.write_all(b")")?;
|
||||||
|
}
|
||||||
|
ExpressionContent::InvokeClassMethod(class, "Constructor", arguments) if class.is_scalar() => {
|
||||||
|
emit_expression(collector, &arguments[0].1)?;
|
||||||
|
}
|
||||||
|
ExpressionContent::InvokeClassMethod(class, "Constructor", arguments) => {
|
||||||
|
collector.write_fmt(format_args!("{} {{ groups: {}Groups {{ ", class.class_name, class.class_name))?;
|
||||||
|
for (i, (_argument_class, argument)) in arguments.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
collector.write_all(b", ")?;
|
||||||
|
}
|
||||||
|
collector.write_fmt(format_args!("g{}: ", i))?;
|
||||||
|
emit_expression(collector, argument)?;
|
||||||
|
}
|
||||||
|
collector.write_all(b" } }")?;
|
||||||
|
}
|
||||||
|
ExpressionContent::InvokeClassMethod(class, method_name, arguments) => {
|
||||||
|
emit_data_type(collector, &DataType::MultiVector(class))?;
|
||||||
|
collector.write_all(b"::")?;
|
||||||
|
camel_to_snake_case(collector, method_name)?;
|
||||||
|
collector.write_all(b"(")?;
|
||||||
|
for (i, (_argument_class, argument)) in arguments.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
collector.write_all(b", ")?;
|
||||||
}
|
}
|
||||||
emit_expression(collector, argument)?;
|
emit_expression(collector, argument)?;
|
||||||
}
|
}
|
||||||
if *method_name == "Constructor" {
|
collector.write_all(b")")?;
|
||||||
collector.write_all(b" } }")?;
|
|
||||||
} else {
|
|
||||||
collector.write_all(b")")?;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ExpressionContent::Conversion(_source_class, _destination_class, inner_expression) => {
|
ExpressionContent::Conversion(_source_class, _destination_class, inner_expression) => {
|
||||||
emit_expression(collector, inner_expression)?;
|
emit_expression(collector, inner_expression)?;
|
||||||
|
|
@ -67,7 +74,9 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
|
||||||
}
|
}
|
||||||
ExpressionContent::Access(inner_expression, array_index) => {
|
ExpressionContent::Access(inner_expression, array_index) => {
|
||||||
emit_expression(collector, inner_expression)?;
|
emit_expression(collector, inner_expression)?;
|
||||||
collector.write_fmt(format_args!(".group{}()", array_index))?;
|
if !inner_expression.is_scalar() {
|
||||||
|
collector.write_fmt(format_args!(".group{}()", array_index))?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ExpressionContent::Swizzle(inner_expression, indices) => {
|
ExpressionContent::Swizzle(inner_expression, indices) => {
|
||||||
if expression.size == 1 {
|
if expression.size == 1 {
|
||||||
|
|
@ -89,28 +98,34 @@ fn emit_expression<W: std::io::Write>(collector: &mut W, expression: &Expression
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExpressionContent::Gather(inner_expression, indices) => {
|
ExpressionContent::Gather(inner_expression, indices) => {
|
||||||
if expression.size > 1 {
|
if expression.size == 1 && inner_expression.is_scalar() {
|
||||||
emit_data_type(collector, &DataType::SimdVector(expression.size))?;
|
|
||||||
collector.write_all(b"::from(")?;
|
|
||||||
}
|
|
||||||
if indices.len() > 1 {
|
|
||||||
collector.write_all(b"[")?;
|
|
||||||
}
|
|
||||||
for (i, (array_index, component_index)) in indices.iter().enumerate() {
|
|
||||||
if i > 0 {
|
|
||||||
collector.write_all(b", ")?;
|
|
||||||
}
|
|
||||||
emit_expression(collector, inner_expression)?;
|
emit_expression(collector, inner_expression)?;
|
||||||
collector.write_fmt(format_args!(".group{}()", array_index))?;
|
} else {
|
||||||
if inner_expression.size > 1 {
|
if expression.size > 1 {
|
||||||
collector.write_fmt(format_args!("[{}]", *component_index))?;
|
emit_data_type(collector, &DataType::SimdVector(expression.size))?;
|
||||||
|
collector.write_all(b"::from(")?;
|
||||||
|
}
|
||||||
|
if indices.len() > 1 {
|
||||||
|
collector.write_all(b"[")?;
|
||||||
|
}
|
||||||
|
for (i, (array_index, component_index)) in indices.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
collector.write_all(b", ")?;
|
||||||
|
}
|
||||||
|
emit_expression(collector, inner_expression)?;
|
||||||
|
if !inner_expression.is_scalar() {
|
||||||
|
collector.write_fmt(format_args!(".group{}()", array_index))?;
|
||||||
|
if inner_expression.size > 1 {
|
||||||
|
collector.write_fmt(format_args!("[{}]", *component_index))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if indices.len() > 1 {
|
||||||
|
collector.write_all(b"]")?;
|
||||||
|
}
|
||||||
|
if expression.size > 1 {
|
||||||
|
collector.write_all(b")")?;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if indices.len() > 1 {
|
|
||||||
collector.write_all(b"]")?;
|
|
||||||
}
|
|
||||||
if expression.size > 1 {
|
|
||||||
collector.write_all(b")")?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExpressionContent::Constant(data_type, values) => match data_type {
|
ExpressionContent::Constant(data_type, values) => match data_type {
|
||||||
|
|
@ -172,17 +187,15 @@ fn emit_assign_trait<W: std::io::Write>(collector: &mut W, result: &Parameter, p
|
||||||
if result.multi_vector_class() != parameters[0].multi_vector_class() {
|
if result.multi_vector_class() != parameters[0].multi_vector_class() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
collector.write_fmt(format_args!(
|
collector.write_fmt(format_args!("impl {}Assign<", result.name))?;
|
||||||
"impl {}Assign<{}> for {} {{\n fn ",
|
emit_data_type(collector, ¶meters[1].data_type)?;
|
||||||
result.name,
|
collector.write_all(b"> for ")?;
|
||||||
parameters[1].multi_vector_class().class_name,
|
emit_data_type(collector, ¶meters[0].data_type)?;
|
||||||
parameters[0].multi_vector_class().class_name
|
collector.write_all(b" {\n fn ")?;
|
||||||
))?;
|
|
||||||
camel_to_snake_case(collector, result.name)?;
|
camel_to_snake_case(collector, result.name)?;
|
||||||
collector.write_fmt(format_args!(
|
collector.write_all(b"_assign(&mut self, other: ")?;
|
||||||
"_assign(&mut self, other: {}) {{\n *self = (*self).",
|
emit_data_type(collector, ¶meters[1].data_type)?;
|
||||||
parameters[1].multi_vector_class().class_name
|
collector.write_all(b") {\n *self = (*self).")?;
|
||||||
))?;
|
|
||||||
camel_to_snake_case(collector, result.name)?;
|
camel_to_snake_case(collector, result.name)?;
|
||||||
collector.write_all(b"(other);\n }\n}\n\n")
|
collector.write_all(b"(other);\n }\n}\n\n")
|
||||||
}
|
}
|
||||||
|
|
@ -196,6 +209,9 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
|
||||||
.write_all(b"use crate::{simd::*, *};\nuse std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};\n\n")?;
|
.write_all(b"use crate::{simd::*, *};\nuse std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};\n\n")?;
|
||||||
}
|
}
|
||||||
AstNode::ClassDefinition { class } => {
|
AstNode::ClassDefinition { class } => {
|
||||||
|
if class.is_scalar() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let element_count = class.grouped_basis.iter().fold(0, |a, b| a + b.len());
|
let element_count = class.grouped_basis.iter().fold(0, |a, b| a + b.len());
|
||||||
let mut simd_widths = Vec::new();
|
let mut simd_widths = Vec::new();
|
||||||
emit_indentation(collector, indentation)?;
|
emit_indentation(collector, indentation)?;
|
||||||
|
|
@ -471,30 +487,40 @@ pub fn emit_code<W: std::io::Write>(collector: &mut W, ast_node: &AstNode, inden
|
||||||
collector.write_all(b"}\n")?;
|
collector.write_all(b"}\n")?;
|
||||||
}
|
}
|
||||||
AstNode::TraitImplementation { result, parameters, body } => {
|
AstNode::TraitImplementation { result, parameters, body } => {
|
||||||
match parameters.len() {
|
if result.data_type.is_scalar()
|
||||||
0 => collector.write_fmt(format_args!("impl {} for {}", result.name, result.multi_vector_class().class_name))?,
|
&& !parameters
|
||||||
1 if result.name == "Into" => collector.write_fmt(format_args!(
|
.iter()
|
||||||
"impl {}<{}> for {}",
|
.any(|parameter| matches!(parameter.data_type, DataType::MultiVector(class) if !class.is_scalar()))
|
||||||
result.name,
|
{
|
||||||
result.multi_vector_class().class_name,
|
return Ok(());
|
||||||
parameters[0].multi_vector_class().class_name,
|
|
||||||
))?,
|
|
||||||
1 => collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))?,
|
|
||||||
2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => {
|
|
||||||
collector.write_fmt(format_args!("impl {} for {}", result.name, parameters[0].multi_vector_class().class_name))?
|
|
||||||
}
|
|
||||||
2 => collector.write_fmt(format_args!(
|
|
||||||
"impl {}<{}> for {}",
|
|
||||||
result.name,
|
|
||||||
parameters[1].multi_vector_class().class_name,
|
|
||||||
parameters[0].multi_vector_class().class_name,
|
|
||||||
))?,
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
}
|
||||||
|
collector.write_fmt(format_args!("impl {}", result.name))?;
|
||||||
|
let impl_for = match parameters.len() {
|
||||||
|
0 => &result.data_type,
|
||||||
|
1 if result.name == "Into" => {
|
||||||
|
collector.write_all(b"<")?;
|
||||||
|
emit_data_type(collector, &result.data_type)?;
|
||||||
|
collector.write_all(b">")?;
|
||||||
|
¶meters[0].data_type
|
||||||
|
}
|
||||||
|
1 => ¶meters[0].data_type,
|
||||||
|
2 if !matches!(parameters[1].data_type, DataType::MultiVector(_)) => ¶meters[0].data_type,
|
||||||
|
2 => {
|
||||||
|
collector.write_all(b"<")?;
|
||||||
|
emit_data_type(collector, ¶meters[1].data_type)?;
|
||||||
|
collector.write_all(b">")?;
|
||||||
|
¶meters[0].data_type
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
collector.write_all(b" for ")?;
|
||||||
|
emit_data_type(collector, impl_for)?;
|
||||||
collector.write_all(b" {\n")?;
|
collector.write_all(b" {\n")?;
|
||||||
if !parameters.is_empty() && result.name != "Into" {
|
if !parameters.is_empty() && result.name != "Into" {
|
||||||
emit_indentation(collector, indentation + 1)?;
|
emit_indentation(collector, indentation + 1)?;
|
||||||
collector.write_fmt(format_args!("type Output = {};\n\n", result.multi_vector_class().class_name))?;
|
collector.write_all(b"type Output = ")?;
|
||||||
|
emit_data_type(collector, &result.data_type)?;
|
||||||
|
collector.write_all(b";\n\n")?;
|
||||||
}
|
}
|
||||||
emit_indentation(collector, indentation + 1)?;
|
emit_indentation(collector, indentation + 1)?;
|
||||||
collector.write_all(b"fn ")?;
|
collector.write_all(b"fn ")?;
|
||||||
|
|
|
||||||
149
src/lib.rs
149
src/lib.rs
|
|
@ -10,17 +10,144 @@ pub mod hpga3d;
|
||||||
pub mod simd;
|
pub mod simd;
|
||||||
pub mod polynomial;
|
pub mod polynomial;
|
||||||
|
|
||||||
impl epga1d::Scalar {
|
impl Zero for f32 {
|
||||||
pub fn real(self) -> f32 {
|
fn zero() -> Self {
|
||||||
self[0]
|
0.0
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn sqrt(self) -> epga1d::ComplexNumber {
|
impl One for f32 {
|
||||||
if self[0] < 0.0 {
|
fn one() -> Self {
|
||||||
epga1d::ComplexNumber::new(0.0, (-self[0]).sqrt())
|
1.0
|
||||||
} else {
|
}
|
||||||
epga1d::ComplexNumber::new(self[0].sqrt(), 0.0)
|
}
|
||||||
}
|
|
||||||
|
impl Automorphism for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn automorphism(self) -> f32 {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Reversal for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn reversal(self) -> f32 {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Conjugation for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn conjugation(self) -> f32 {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GeometricProduct<f32> for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn geometric_product(self, other: f32) -> f32 {
|
||||||
|
self * other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OuterProduct<f32> for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn outer_product(self, other: f32) -> f32 {
|
||||||
|
self * other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InnerProduct<f32> for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn inner_product(self, other: f32) -> f32 {
|
||||||
|
self * other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LeftContraction<f32> for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn left_contraction(self, other: f32) -> f32 {
|
||||||
|
self * other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RightContraction<f32> for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn right_contraction(self, other: f32) -> f32 {
|
||||||
|
self * other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScalarProduct<f32> for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn scalar_product(self, other: f32) -> f32 {
|
||||||
|
self * other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SquaredMagnitude for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn squared_magnitude(self) -> f32 {
|
||||||
|
self.scalar_product(self.reversal())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Magnitude for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn magnitude(self) -> f32 {
|
||||||
|
self.abs()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Scale for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn scale(self, other: f32) -> f32 {
|
||||||
|
self.geometric_product(other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Signum for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn signum(self) -> f32 {
|
||||||
|
f32::signum(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Inverse for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn inverse(self) -> f32 {
|
||||||
|
1.0 / self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GeometricQuotient<f32> for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn geometric_quotient(self, other: f32) -> f32 {
|
||||||
|
self.geometric_product(other.inverse())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Transformation<f32> for f32 {
|
||||||
|
type Output = f32;
|
||||||
|
|
||||||
|
fn transformation(self, other: f32) -> f32 {
|
||||||
|
self.geometric_product(other)
|
||||||
|
.geometric_product(self.reversal())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -54,7 +181,7 @@ impl Ln for epga1d::ComplexNumber {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn ln(self) -> Self {
|
fn ln(self) -> Self {
|
||||||
Self::new(self.magnitude()[0].ln(), self.arg())
|
Self::new(self.magnitude().ln(), self.arg())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -62,7 +189,7 @@ impl Powf for epga1d::ComplexNumber {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn powf(self, exponent: f32) -> Self {
|
fn powf(self, exponent: f32) -> Self {
|
||||||
Self::from_polar(self.magnitude()[0].powf(exponent), self.arg() * exponent)
|
Self::from_polar(self.magnitude().powf(exponent), self.arg() * exponent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ pub fn solve_quadratic(coefficients: [f32; 3], error_margin: f32) -> (f32, Vec<R
|
||||||
}
|
}
|
||||||
// https://en.wikipedia.org/wiki/Quadratic_formula
|
// https://en.wikipedia.org/wiki/Quadratic_formula
|
||||||
let discriminant = coefficients[1].powi(2) - 4.0 * coefficients[2] * coefficients[0];
|
let discriminant = coefficients[1].powi(2) - 4.0 * coefficients[2] * coefficients[0];
|
||||||
let q = Scalar::new(discriminant).sqrt();
|
let q = discriminant.sqrt();
|
||||||
let mut solutions = Vec::with_capacity(3);
|
let mut solutions = Vec::with_capacity(3);
|
||||||
for s in [-q, q] {
|
for s in [-q, q] {
|
||||||
let numerator = s - ComplexNumber::new(coefficients[1], 0.0);
|
let numerator = s - ComplexNumber::new(coefficients[1], 0.0);
|
||||||
|
|
@ -93,10 +93,9 @@ pub fn solve_cubic(coefficients: [f32; 4], error_margin: f32) -> (f32, Vec<Root>
|
||||||
];
|
];
|
||||||
let mut solutions = Vec::with_capacity(3);
|
let mut solutions = Vec::with_capacity(3);
|
||||||
let discriminant = d[1].powi(2) - 4.0 * d[0].powi(3);
|
let discriminant = d[1].powi(2) - 4.0 * d[0].powi(3);
|
||||||
let c = Scalar::new(discriminant).sqrt();
|
let c = discriminant.sqrt();
|
||||||
let c = ((c + ComplexNumber::new(if c.real() + d[1] == 0.0 { -d[1] } else { d[1] }, 0.0))
|
let c = ((c + ComplexNumber::new(if c + d[1] == 0.0 { -d[1] } else { d[1] }, 0.0)).scale(0.5))
|
||||||
.scale(0.5))
|
.powf(1.0 / 3.0);
|
||||||
.powf(1.0 / 3.0);
|
|
||||||
for root_of_unity in &ROOTS_OF_UNITY_3 {
|
for root_of_unity in &ROOTS_OF_UNITY_3 {
|
||||||
let ci = c.geometric_product(*root_of_unity);
|
let ci = c.geometric_product(*root_of_unity);
|
||||||
let denominator = ci.scale(3.0 * coefficients[3]);
|
let denominator = ci.scale(3.0 * coefficients[3]);
|
||||||
|
|
@ -105,7 +104,7 @@ pub fn solve_cubic(coefficients: [f32; 4], error_margin: f32) -> (f32, Vec<Root>
|
||||||
.geometric_product(denominator.reversal());
|
.geometric_product(denominator.reversal());
|
||||||
solutions.push(Root {
|
solutions.push(Root {
|
||||||
numerator,
|
numerator,
|
||||||
denominator: denominator.squared_magnitude().real(),
|
denominator: denominator.squared_magnitude(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
let real_root =
|
let real_root =
|
||||||
|
|
@ -144,10 +143,9 @@ pub fn solve_quartic(coefficients: [f32; 5], error_margin: f32) -> (f32, Vec<Roo
|
||||||
- 72.0 * coefficients[4] * coefficients[2] * coefficients[0],
|
- 72.0 * coefficients[4] * coefficients[2] * coefficients[0],
|
||||||
];
|
];
|
||||||
let discriminant = d[1].powi(2) - 4.0 * d[0].powi(3);
|
let discriminant = d[1].powi(2) - 4.0 * d[0].powi(3);
|
||||||
let c = Scalar::new(discriminant).sqrt();
|
let c = discriminant.sqrt();
|
||||||
let c = ((c + ComplexNumber::new(if c.real() + d[1] == 0.0 { -d[1] } else { d[1] }, 0.0))
|
let c = ((c + ComplexNumber::new(if c + d[1] == 0.0 { -d[1] } else { d[1] }, 0.0)).scale(0.5))
|
||||||
.scale(0.5))
|
.powf(1.0 / 3.0);
|
||||||
.powf(1.0 / 3.0);
|
|
||||||
let e = ((c + ComplexNumber::new(d[0], 0.0).geometric_quotient(c))
|
let e = ((c + ComplexNumber::new(d[0], 0.0).geometric_quotient(c))
|
||||||
.scale(1.0 / (3.0 * coefficients[4]))
|
.scale(1.0 / (3.0 * coefficients[4]))
|
||||||
- ComplexNumber::new(p * 2.0 / 3.0, 0.0))
|
- ComplexNumber::new(p * 2.0 / 3.0, 0.0))
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue