diff --git a/essm_jax/tests/test_jvp_op.py b/essm_jax/tests/test_jvp_op.py index 5022f45..1799632 100644 --- a/essm_jax/tests/test_jvp_op.py +++ b/essm_jax/tests/test_jvp_op.py @@ -16,7 +16,7 @@ def test_jvp_linear_op(): def fn(x): return jnp.asarray([jnp.sum(jnp.sin(x) ** i) for i in range(m)]) - x = jnp.arange(n).astype(jnp.float32) + x = jnp.arange(n).astype(float) jvp_op = JVPLinearOp(fn) jvp_op = jvp_op(x) @@ -73,8 +73,8 @@ def test_multiple_primals(init_primals: bool): def fn(x, y): return jnp.stack([x * y, y, -y], axis=-1) # [n, 3] - x = jnp.arange(n).astype(jnp.float32) - y = jnp.arange(n).astype(jnp.float32) + x = jnp.arange(n).astype(float) + y = jnp.arange(n).astype(float) if init_primals: jvp_op = JVPLinearOp(fn, primals=(x, y)) else: