From c806f94dc1222817f65ac854277235d73512fc0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Wed, 20 Sep 2023 16:10:25 +0200 Subject: [PATCH] Improves MultiVectorClass::is_scalar(). --- codegen/src/compile.rs | 2 +- codegen/src/main.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codegen/src/compile.rs b/codegen/src/compile.rs index f531be4..c7db1cd 100644 --- a/codegen/src/compile.rs +++ b/codegen/src/compile.rs @@ -147,7 +147,7 @@ impl MultiVectorClass { } pub fn is_scalar(&self) -> bool { - self.flat_basis() == vec![BasisElement { scalar: 1, index: 0 }] + self.grouped_basis == vec![vec![BasisElement::from_index(0)]] } pub fn signature(&self) -> Vec { diff --git a/codegen/src/main.rs b/codegen/src/main.rs index fe0831a..7d76bd2 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -141,7 +141,7 @@ fn main() { } for (parameter_b, pair_trait_implementations) in pair_trait_implementations.values() { if let Some(geometric_product) = pair_trait_implementations.get("GeometricProduct") { - if parameter_b.multi_vector_class().grouped_basis == vec![vec![BasisElement::from_index(0)]] { + if parameter_b.data_type.is_scalar() { let scale = MultiVectorClass::derive_scale("Scale", geometric_product, ¶meter_a, parameter_b); emitter.emit(&scale).unwrap(); if let Some(magnitude) = single_trait_implementations.get("Magnitude") {