Skip to content

Commit

Permalink
* fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Aug 14, 2024
1 parent 0c51ac8 commit 0aa3bd4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions essm_jax/tests/test_jvp_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_jvp_linear_op():
def fn(x):
return jnp.asarray([jnp.sum(jnp.sin(x) ** i) for i in range(m)])

x = jnp.arange(n).astype(jnp.float32)
x = jnp.arange(n).astype(float)

jvp_op = JVPLinearOp(fn)
jvp_op = jvp_op(x)
Expand Down Expand Up @@ -73,8 +73,8 @@ def test_multiple_primals(init_primals: bool):
def fn(x, y):
return jnp.stack([x * y, y, -y], axis=-1) # [n, 3]

x = jnp.arange(n).astype(jnp.float32)
y = jnp.arange(n).astype(jnp.float32)
x = jnp.arange(n).astype(float)
y = jnp.arange(n).astype(float)
if init_primals:
jvp_op = JVPLinearOp(fn, primals=(x, y))
else:
Expand Down

0 comments on commit 0aa3bd4

Please sign in to comment.