Skip to content

Commit

Permalink
Add division and negation for SO3Signal.
Browse files Browse the repository at this point in the history
  • Loading branch information
ameya98 committed Dec 15, 2024
1 parent 08a7e81 commit 5d72147
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion e3nn_jax/_src/so3grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,23 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":

return SO3Signal(self.s2_signals * other)

def __truediv__(self, other: float) -> "SO3Signal":
def __rmul__(self, other: float) -> "SO3Signal":
return self * other

def __neg__(self) -> "SO3Signal":
return self * -1

def __truediv__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
if isinstance(other, SO3Signal):
if self.shape != other.shape:
raise ValueError(
f"Shapes of the two signals do not match: {self.shape} != {other.shape}"
)

return self.replace_values(
self.grid_values / other.grid_values
)

return self * (1 / other)

def apply(self, func: Callable[..., jnp.ndarray]) -> "SO3Signal":
Expand Down

0 comments on commit 5d72147

Please sign in to comment.