From 5d72147f2087397d5fad3c251edf7a3da808841a Mon Sep 17 00:00:00 2001 From: Ameya Daigavane Date: Sun, 15 Dec 2024 16:37:18 +0530 Subject: [PATCH] Add division and negation for SO3Signal. --- e3nn_jax/_src/so3grid.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/e3nn_jax/_src/so3grid.py b/e3nn_jax/_src/so3grid.py index 6c94c77..fbefdd2 100644 --- a/e3nn_jax/_src/so3grid.py +++ b/e3nn_jax/_src/so3grid.py @@ -125,7 +125,23 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal": return SO3Signal(self.s2_signals * other) - def __truediv__(self, other: float) -> "SO3Signal": + def __rmul__(self, other: float) -> "SO3Signal": + return self * other + + def __neg__(self) -> "SO3Signal": + return self * -1 + + def __truediv__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal": + if isinstance(other, SO3Signal): + if self.shape != other.shape: + raise ValueError( + f"Shapes of the two signals do not match: {self.shape} != {other.shape}" + ) + + return self.replace_values( + self.grid_values / other.grid_values + ) + return self * (1 / other) def apply(self, func: Callable[..., jnp.ndarray]) -> "SO3Signal":