diff --git a/codegen/src/ast.rs b/codegen/src/ast.rs index 7d7fa51..bb28279 100644 --- a/codegen/src/ast.rs +++ b/codegen/src/ast.rs @@ -12,7 +12,13 @@ pub enum ExpressionContent<'a> { None, Variable(DataType<'a>, &'static str), InvokeClassMethod(&'a MultiVectorClass, &'static str, Vec<(DataType<'a>, Expression<'a>)>), - InvokeInstanceMethod(DataType<'a>, Box>, &'static str, Vec<(DataType<'a>, Expression<'a>)>), + InvokeInstanceMethod( + DataType<'a>, + Box>, + &'static str, + DataType<'a>, + Vec<(DataType<'a>, Expression<'a>)>, + ), Conversion(&'a MultiVectorClass, &'a MultiVectorClass, Box>), Select(Box>, Box>, Box>), Access(Box>, usize), diff --git a/codegen/src/compile.rs b/codegen/src/compile.rs index 2295f96..6c46fba 100644 --- a/codegen/src/compile.rs +++ b/codegen/src/compile.rs @@ -618,6 +618,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), scalar_product_result.name, + scalar_product_result.data_type.clone(), vec![( DataType::MultiVector(involution_result.multi_vector_class()), Expression { @@ -629,6 +630,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), involution_result.name, + involution_result.data_type.clone(), vec![], ), }, @@ -668,6 +670,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), geometric_product_result.name, + geometric_product_result.data_type.clone(), vec![( DataType::MultiVector(parameter_b.multi_vector_class()), Expression { @@ -721,6 +724,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), squared_magnitude_result.name, + squared_magnitude_result.data_type.clone(), vec![], ), }), @@ -759,6 +763,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), geometric_product_result.name, + geometric_product_result.data_type.clone(), vec![( DataType::MultiVector(magnitude_result.multi_vector_class()), Expression { @@ -790,6 +795,7 @@ impl MultiVectorClass { ), }), magnitude_result.name, + magnitude_result.data_type.clone(), vec![], ), }), @@ -838,10 +844,12 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), involution_result.name, + involution_result.data_type.clone(), vec![], ), }), geometric_product_result.name, + geometric_product_result.data_type.clone(), vec![( DataType::MultiVector(squared_magnitude_result.multi_vector_class()), Expression { @@ -873,6 +881,7 @@ impl MultiVectorClass { ), }), squared_magnitude_result.name, + squared_magnitude_result.data_type.clone(), vec![], ), }), @@ -958,6 +967,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), inverse_result.name, + inverse_result.data_type.clone(), vec![], ), }), @@ -988,6 +998,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_b.data_type.clone(), parameter_b.name), }), "Abs", + DataType::Integer, vec![], ), }), @@ -1042,6 +1053,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), "x"), }), geometric_product_result.name, + geometric_product_result.data_type.clone(), vec![( DataType::MultiVector(parameter_a.multi_vector_class()), Expression { @@ -1065,6 +1077,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), "x"), }), geometric_product_result.name, + geometric_product_result.data_type.clone(), vec![( DataType::MultiVector(parameter_a.multi_vector_class()), Expression { @@ -1104,6 +1117,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), "x"), }), geometric_product_result.name, + geometric_product_result.data_type.clone(), vec![( DataType::MultiVector(parameter_a.multi_vector_class()), Expression { @@ -1143,6 +1157,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), geometric_product_result.name, + geometric_product_result.data_type.clone(), vec![( DataType::MultiVector(inverse_result.multi_vector_class()), Expression { @@ -1154,6 +1169,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_b.data_type.clone(), parameter_b.name), }), inverse_result.name, + inverse_result.data_type.clone(), vec![], ), }, @@ -1189,6 +1205,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), geometric_product_result.name, + geometric_product_result.data_type.clone(), vec![( DataType::MultiVector(parameter_b.multi_vector_class()), Expression { @@ -1199,6 +1216,7 @@ impl MultiVectorClass { ), }), geometric_product_2_result.name, + geometric_product_2_result.data_type.clone(), vec![( DataType::MultiVector(involution_result.multi_vector_class()), Expression { @@ -1210,6 +1228,7 @@ impl MultiVectorClass { content: ExpressionContent::Variable(parameter_a.data_type.clone(), parameter_a.name), }), involution_result.name, + involution_result.data_type.clone(), vec![], ), }, diff --git a/codegen/src/glsl.rs b/codegen/src/glsl.rs index b761a83..bab3ddb 100644 --- a/codegen/src/glsl.rs +++ b/codegen/src/glsl.rs @@ -20,9 +20,9 @@ fn emit_expression(collector: &mut W, expression: &Expression ExpressionContent::Variable(_data_type, name) => { collector.write_all(name.bytes().collect::>().as_slice())?; } - ExpressionContent::InvokeClassMethod(_, _, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, _, arguments) => { + ExpressionContent::InvokeClassMethod(_, _, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, _, _, arguments) => { match &expression.content { - ExpressionContent::InvokeInstanceMethod(result_class, inner_expression, method_name, _) => { + ExpressionContent::InvokeInstanceMethod(result_class, inner_expression, method_name, _, _) => { if let DataType::MultiVector(result_class) = result_class { camel_to_snake_case(collector, &result_class.class_name)?; collector.write_all(b"_")?; diff --git a/codegen/src/rust.rs b/codegen/src/rust.rs index 590c9be..a141e03 100644 --- a/codegen/src/rust.rs +++ b/codegen/src/rust.rs @@ -18,9 +18,9 @@ fn emit_expression(collector: &mut W, expression: &Expression ExpressionContent::Variable(_data_type, name) => { collector.write_all(name.bytes().collect::>().as_slice())?; } - ExpressionContent::InvokeClassMethod(_, method_name, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, method_name, arguments) => { + ExpressionContent::InvokeClassMethod(_, method_name, arguments) | ExpressionContent::InvokeInstanceMethod(_, _, method_name, _, arguments) => { match &expression.content { - ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, _, _) => { + ExpressionContent::InvokeInstanceMethod(_result_class, inner_expression, _, _, _) => { emit_expression(collector, inner_expression)?; collector.write_all(b".")?; }