Skip to content

Commit

Permalink
Add enzyme rules for incomplete Elliptic Pi & Bump to version 0.3.3
Browse files Browse the repository at this point in the history
  • Loading branch information
dominic-chang committed Nov 27, 2024
1 parent cd4b350 commit 202010f
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 48 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JacobiElliptic"
uuid = "2a8b799e-c098-4961-872a-356c768d184c"
authors = ["dominicchang <dochang@g.harvard.com>"]
version = "0.3.2"
version = "0.3.3"

[deps]
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -23,4 +23,4 @@ Enzyme = "0.13"
ForwardDiff = "0.10"
StaticArrays = "1.6"
Zygote = "0.6"
julia = "1.8"
julia = "1.10"
140 changes: 139 additions & 1 deletion ext/JacobiEllipticEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function forward(
i ->isa Const ? zero.val) : ∂F_∂ϕ.val, m.val)*ϕ.dval[i]) + (m isa Const ? zero(m.val) : ∂F_∂m.val, m.val)*m.dval[i]), Val(EnzymeRules.width(config))
)
)
d end
end
elseif EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
returnisa Const ? zero.val) : ∂F_∂ϕ.val, m.val)*ϕ.dval) +(m isa Const ? zero(m.val) : ∂F_∂m.val, m.val)*m.dval)
Expand Down Expand Up @@ -211,4 +211,142 @@ function reverse(
end
return (dϕ, dm)
end

#----------------------------------------------------------------------------------------
# Elliptic Pi(n, ϕ, m)
#----------------------------------------------------------------------------------------
function ∂Pi_∂n(n, ϕ, m)
return (
JacobiElliptic.CarlsonAlg.E(ϕ, m) +
(m-n)*JacobiElliptic.CarlsonAlg.F(ϕ, m)/n +
(n^2-m)*JacobiElliptic.CarlsonAlg.Pi(n, ϕ, m)/n -
n*√(1-m*sin(ϕ)^2)*sin(2ϕ) / (2(1 - n*sin(ϕ)^2))
) / (2 * (m-n)*(n-1))
end

function ∂Pi_∂m(n, ϕ, m)
return (
JacobiElliptic.CarlsonAlg.E(ϕ, m) / (m-1) +
JacobiElliptic.CarlsonAlg.Pi(n, ϕ, m) -
m*sin(2*ϕ) / (2*(m-1) * (1 - m*sin(ϕ)^2))
) / (2 * (n-m))
end

function ∂Pi_∂ϕ(n, ϕ, m)
return 1 / ((1 - m*sin(ϕ)^2)*(1-n*sin(ϕ)^2))
end


function forward(
# https://enzymead.github.io/Enzyme.jl/stable/#Forward-mode
# Of note, when we seed both arguments at once the tangent return is the sum of both.
config::EnzymeRules.FwdConfig,
func::Const{typeof(JacobiElliptic.CarlsonAlg.Pi)},
RT,
n::Annotation{<:Real},
ϕ::Annotation{<:Real},
m::Annotation{<:Real}
)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
return Duplicated(
func.val(n.val, ϕ.val, m.val),
(n isa Const ? zero(n.val) : ∂Pi_∂n(n.val, ϕ.val, m.val)*n.dval) +isa Const ? zero.val) : ∂Pi_∂ϕ(n.val, ϕ.val, m.val)*ϕ.dval) + (m isa Const ? zero(m.val) : ∂Pi_∂m(n.val, ϕ.val, m.val)*m.dval)
)
else
return BatchDuplicated(
func.val(n.val, ϕ.val, m.val),
ntuple(
i -> (n isa Const ? zero(n.val) : ∂Pi_∂n(n.val, ϕ.val, m.val)*n.dval[i]) +isa Const ? zero.val) : ∂Pi_∂ϕ(n.val, ϕ.val, m.val)*ϕ.dval[i]) + (m isa Const ? zero(m.val) : ∂Pi_∂m(n.val, ϕ.val, m.val)*m.dval[i]), Val(EnzymeRules.width(config))
)
)
end
elseif EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
return (n isa Const ? zero(n.val) : ∂Pi_∂n(n.val, ϕ.val, m.val)*n.dval) +isa Const ? zero.val) : ∂Pi_∂ϕ(n.val, ϕ.val, m.val)*ϕ.dval) + (m isa Const ? zero(m.val) : ∂Pi_∂m(n.val, ϕ.val, m.val)*m.dval)
else
return ntuple(i -> (n isa Const ? zero(n.val) : ∂Pi_∂n(n.val, ϕ.val, m.val)*n.dval[i]) +isa Const ? zero.val) : ∂Pi_∂ϕ(n.val, ϕ.val, m.val)*ϕ.dval[i]) + (m isa Const ? zero(m.val) : ∂Pi_∂m(n.val, ϕ.val, m.val)*m.dval[i]), Val(EnzymeRules.width(config)))
end
elseif EnzymeRules.needs_primal(config)
return func.val(n.val, ϕ.val, m.val)
else
return nothing
end
end

function augmented_primal(
config::RevConfigWidth,
func::Const{typeof(JacobiElliptic.CarlsonAlg.F)},
::Type,
n::Annotation{<:Real},
ϕ::Annotation{<:Real},
m::Annotation{<:Real}
)
primal = EnzymeRules.needs_primal(config) ? func.val(n.val, ϕ.val, m.val) : nothing

return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function reverse(
config::RevConfigWidth,
func::Const{typeof(JacobiElliptic.CarlsonAlg.F)},
dret,
tape,
n::Annotation{T},
ϕ::Annotation{T},
m::Annotation{T}
) where T
dn = if n isa Const
nothing
elseif EnzymeRules.width(config) == 1
if dret isa Type{<:Const}
zero(n.val)
else
∂Pi_∂n(n.val, ϕ.val, m.val) * dret.val
end
else
if dret isa Type{<:Const}
ntuple(i -> zero(n.val), Val(EnzymeRules.width(config)))
else
ntuple(i -> ∂Pi_∂n(n.val, ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
end
end


= if ϕ isa Const
nothing
elseif EnzymeRules.width(config) == 1
if dret isa Type{<:Const}
zero.val)
else
∂Pi_∂ϕ(n.val, ϕ.val, m.val) * dret.val
end
else
if dret isa Type{<:Const}
ntuple(i -> zero.val), Val(EnzymeRules.width(config)))
else
ntuple(i -> ∂Pi_∂ϕ(n.val, ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
end
end

dm = if m isa Const
nothing
elseif EnzymeRules.width(config) == 1
if dret isa Type{<:Const}
zero.val)
else
∂Pi_∂m(n.val, ϕ.val, m.val) * dret.val
end
else
if dret isa Type{<:Const}
ntuple(i -> zero.val), Val(EnzymeRules.width(config)))
else
ntuple(i -> ∂Pi_∂m(n.val, ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
end
end
return (dn, dϕ, dm)
end



end
135 changes: 90 additions & 45 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,62 +7,107 @@ using SpecialFunctions
@testset "Zygote and ForwardDiff" begin

num_trials = 1
ks = rand(num_trials)
ms = rand(num_trials)
ϕs = rand(num_trials) .* 2π
ns = rand(num_trials)

@show ks, ϕs, ns
@show ms, ϕs, ns

# Test several known derivative identities across a wide range of values of ϕ and m
# in order to verify that derivatives work correctly
@testset for (k, ϕ, n) in zip(ks, ϕs, ns)
m = k^2
dk_dm = inv(2k)
@testset for (m, ϕ, n) in zip(ms, ϕs, ns)

# I. Tests for complete integrals, from https://en.wikipedia.org/wiki/Elliptic_integral
# 1. K'(m) == K'(k) * dk/dm == E(k) / (k * (1 - k^2)) - K(k)/k
@test Zygote.gradient(alg.K, m)[1] (alg.E(m) / (k * (1 - k^2)) - alg.K(m) / k) * dk_dm
@test ForwardDiff.derivative(alg.K, m) (alg.E(m) / (k * (1 - k^2)) - alg.K(m) / k) * dk_dm
@test Enzyme.autodiff(Reverse, alg.K, Active, Active(m))[1][1] (alg.E(m) / (k * (1 - k^2)) - alg.K(m) / k) * dk_dm
# 1. K'(m) = E(m) / (2m * (1 - m)) - K(m)/2m
@testset "Complete K" begin
grad = (alg.E(m) / (2m * (1 - m)) - alg.K(m) / (2m))
@test Zygote.gradient(alg.K, m)[1] grad
@test ForwardDiff.derivative(alg.K, m) grad
@test Enzyme.autodiff(Reverse, alg.K, Active, Active(m))[1][1] grad
@test Enzyme.autodiff(Forward, alg.K, Duplicated, Duplicated(m, 1.0))[1][1] grad
end

# 2. E'(m) == E'(k) * dk/dm == (E(k) - K(k))/k
@test Zygote.gradient(alg.E, m)[1] (alg.E(m) - alg.K(m))/k * dk_dm
@test ForwardDiff.derivative(alg.E, m) (alg.E(m) - alg.K(m))/k * dk_dm
@test Enzyme.autodiff(Reverse, alg.E, Active, Active(m))[1][1] (alg.E(m) - alg.K(m))/k * dk_dm
# 2. E'(m) = (E(m) - K(m))/2m
@testset "Complete E" begin
grad = (alg.E(m) - alg.K(m))/2m
@test Zygote.gradient(alg.E, m)[1] grad
@test ForwardDiff.derivative(alg.E, m) grad
@test Enzyme.autodiff(Reverse, alg.E, Active, Active(m))[1][1] grad
@test Enzyme.autodiff(Forward, alg.E, Duplicated, Duplicated(m, 1.0))[1][1] grad
end

# 3. ∂_n Pi(n, m) = (E(m) + (m-n)*K(m)/n + (n^2-m)*Pi(n,m)/n)/(2*(m-n)*(n-1))
@testset "Complete Pi" begin
_Pi = alg.Pi
grad = (alg.E(m)/(m-1) + _Pi(n, m))/(2*(n-m))
@test Zygote.gradient(m->_Pi(n, m), m)[1] grad
@test ForwardDiff.derivative(m->_Pi(n, m), m) grad
@test Enzyme.autodiff(Reverse, _Pi, Active, Const(n), Active(m))[1][2] grad
end

# II. Tests for incomplete integrals, from https://functions.wolfram.com/EllipticIntegrals/EllipticF/introductions/IncompleteEllipticIntegrals/ShowAll.html
# 3. ∂ϕ(F(ϕ, m)) == 1 / √(1 - m*sin(ϕ)^2)
_F = alg.F
@test Zygote.gradient-> _F(ϕ, m), ϕ)[1] 1 / (1 - m*sin(ϕ)^2) atol=1e-5
@test ForwardDiff.derivative-> _F(ϕ, m), ϕ) 1 / (1 - m*sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Reverse, _F, Active, Active(ϕ), Const(m))[1][1] 1 / (1 - m*sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Forward, _F, Duplicated, Duplicated(ϕ, 1.0), Const(m))[1][1] 1 / (1 - m*sin(ϕ)^2) atol=1e-5

# 4. ∂m(F(ϕ, m)) == E(ϕ, m) / (2 * m * (1 - m)) - F(ϕ, m) / 2m - sin(2ϕ) / (4 * (1-m) * √(1 - m * sin(ϕ)^2))
@test Zygote.gradient(m -> _F(ϕ, m), m)[1]
alg.E(ϕ, m) / (2 * m * (1 - m)) -
alg.F(ϕ, m) / 2 / m -
sin(2*ϕ) / (4 * (1 - m) * (1 - m * sin(ϕ)^2)) atol=1e-5
@test ForwardDiff.derivative(m -> _F(ϕ, m), m)
alg.E(ϕ, m) / (2 * m * (1 - m)) -
alg.F(ϕ, m) / 2 / m -
sin(2*ϕ) / (4 * (1 - m) * (1 - m * sin(ϕ)^2)) atol=1e-5
@test Enzyme.autodiff(Reverse, _F, Active, Const(ϕ), Active(m))[1][2]
alg.E(ϕ, m) / (2 * m * (1 - m)) -
alg.F(ϕ, m) / 2 / m -
sin(2*ϕ) / (4 * (1 - m) * (1 - m * sin(ϕ)^2)) atol=1e-5

_E = alg.E
# 5. ∂ϕ(E(ϕ, m)) == √(1 - m * sin(ϕ)^2)
@test Zygote.gradient-> _E(ϕ, m), ϕ)[1] (1 - m * sin(ϕ)^2) atol=1e-5
@test ForwardDiff.derivative-> _E(ϕ, m), ϕ) (1 - m * sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Reverse, _E, Active, Active(ϕ), Const(m))[1][1] (1 - m * sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Forward, _E, Duplicated, Duplicated(ϕ, 1.0), Const(m))[1][1] (1 - m*sin(ϕ)^2) atol=1e-5

# 6. ∂m(E(ϕ, m)) == (E(ϕ, m) - F(ϕ, m)) / 2m
@test Zygote.gradient(m -> _E(ϕ, m), m)[1] (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m
@test ForwardDiff.derivative(m -> _E(ϕ, m), m) (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m
@test Enzyme.autodiff(Reverse, _E, Active, Const(ϕ), Active(m))[1][2] (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m atol=1e-5
@testset "Incomplete F" begin
_F = alg.F

# 4. ∂ϕ(F(ϕ, m)) == 1 / √(1 - m*sin(ϕ)^2)
grad = 1 / (1 - m*sin(ϕ)^2)
@test Zygote.gradient-> _F(ϕ, m), ϕ)[1] grad atol=1e-5
@test ForwardDiff.derivative-> _F(ϕ, m), ϕ) grad atol=1e-5
@test Enzyme.autodiff(Reverse, _F, Active, Active(ϕ), Const(m))[1][1] grad atol=1e-5
@test Enzyme.autodiff(Forward, _F, Duplicated, Duplicated(ϕ, 1.0), Const(m))[1][1] grad atol=1e-5

# 5. ∂m(F(ϕ, m)) == E(ϕ, m) / (2 * m * (1 - m)) - F(ϕ, m) / 2m - sin(2ϕ) / (4 * (1-m) * √(1 - m * sin(ϕ)^2))
grad = alg.E(ϕ, m) / (2 * m * (1 - m)) -
alg.F(ϕ, m) / 2 / m -
sin(2*ϕ) / (4 * (1 - m) * (1 - m * sin(ϕ)^2))
@test Zygote.gradient(m -> _F(ϕ, m), m)[1] grad atol=1e-5
@test ForwardDiff.derivative(m -> _F(ϕ, m), m) grad atol=1e-5
@test Enzyme.autodiff(Reverse, _F, Active, Const(ϕ), Active(m))[1][2] grad atol=1e-5
@test Enzyme.autodiff(Forward, _F, Duplicated, Const(ϕ), Duplicated(m, 1.0))[1][1] grad atol=1e-5
end

@testset "Incomplete E" begin
_E = alg.E

# 6. ∂ϕ(E(ϕ, m)) == √(1 - m * sin(ϕ)^2)
grad = (1 - m * sin(ϕ)^2)
@test Zygote.gradient-> _E(ϕ, m), ϕ)[1] grad atol=1e-5
@test ForwardDiff.derivative-> _E(ϕ, m), ϕ) grad atol=1e-5
@test Enzyme.autodiff(Reverse, _E, Active, Active(ϕ), Const(m))[1][1] grad atol=1e-5
@test Enzyme.autodiff(Forward, _E, Duplicated, Duplicated(ϕ, 1.0), Const(m))[1][1] grad atol=1e-5

# 7. ∂m(E(ϕ, m)) == (E(ϕ, m) - F(ϕ, m)) / 2m
grad = (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m
@test Zygote.gradient(m -> _E(ϕ, m), m)[1] grad atol=1e-5
@test ForwardDiff.derivative(m -> _E(ϕ, m), m) grad atol=1e-5
@test Enzyme.autodiff(Reverse, _E, Active, Const(ϕ), Active(m))[1][2] grad atol=1e-5
@test Enzyme.autodiff(Forward, _E, Duplicated, Const(ϕ), Duplicated(m, 1.0))[1][1] grad atol=1e-5
end
@testset "Incomplete Pi" begin
_Pi = alg.Pi

# 7. ∂n(Pi(n, ϕ, m))
grad = (alg.E(ϕ, m) + (m-n)*alg.F(ϕ, m)/n + (n^2-m)*_Pi(n, ϕ, m)/n -n*√(1-m*sin(ϕ)^2)*sin(2ϕ)/(2*(1-n*sin(ϕ)^2)))/(2*(m-n)*(n-1))
@test Zygote.gradient(n -> _Pi(n, ϕ, m), n)[1] grad atol=1e-5
@test ForwardDiff.derivative(n -> _Pi(n, ϕ, m), n) grad atol=1e-5
@test Enzyme.autodiff(Reverse, _Pi, Active, Active(n), Const(ϕ), Const(m))[1][1] grad atol=1e-5
@test Enzyme.autodiff(Forward, _Pi, Duplicated, Duplicated(n, 1.0), Const(ϕ), Const(m))[1][1] grad atol=1e-5


# 8. ∂ϕ(Pi(n, ϕ, m))
grad = 1/((1 - m * sin(ϕ)^2)*(1-n*sin(ϕ)^2))
@test Zygote.gradient-> _Pi(n, ϕ, m), ϕ)[1] grad atol=1e-5
@test ForwardDiff.derivative-> _Pi(n, ϕ, m), ϕ) grad atol=1e-5
@test Enzyme.autodiff(Reverse, _Pi, Active, Const(n), Active(ϕ), Const(m))[1][2] grad atol=1e-5
@test Enzyme.autodiff(Forward, _Pi, Duplicated, Const(n), Duplicated(ϕ, 1.0), Const(m))[1][1] grad atol=1e-5

# 9. ∂m(Pi(n, ϕ, m))
grad = (alg.E(ϕ, m)/(m-1) + _Pi(n, ϕ, m) - m*sin(2ϕ)/(2*(m-1)*√(1-m*sin(ϕ)^2))) / 2(n-m)
@test Zygote.gradient(m -> _Pi(n, ϕ, m), m)[1] grad atol=1e-5
@test ForwardDiff.derivative(m -> _Pi(n, ϕ, m), m) grad atol=1e-5
@test Enzyme.autodiff(Reverse, _Pi, Active, Const(n), Const(ϕ), Active(m))[1][3] grad atol=1e-5
@test Enzyme.autodiff(Forward, _Pi, Duplicated, Const(n), Const(ϕ), Duplicated(m, 1.0))[1][1] grad atol=1e-5
end

end
end
Expand Down

2 comments on commit 202010f

@dominic-chang
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Change:

  • Add Enzyme rules for incomplete Elliptic Pi

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120297

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.3 -m "<description of version>" 202010f56a1dc1ca08eca74b2ad54750d08e00b4
git push origin v0.3.3

Please sign in to comment.