Skip to content

Commit

Permalink
Use DifferentiationInterface for gradient-based analyzers (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Jul 27, 2024
1 parent aaa3c72 commit 0bb2b17
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 18 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# ExplainableAI.jl
## Version `v0.9.0`
- ![Feature][badge-feature] Support selection of AD backend via DifferentiationInterface.jl ([#167])
- `Gradient`, `InputTimesGradient` and `GradCAM` analyzers now have an additional `backend` field and type parameter
- ![Maintenance][badge-maintenance] Update XAIBase interface to v4 ([#166])

## Version `v0.8.0`
This release removes the automatic reexport of heatmapping functionality.
Users are now required to manually load
Expand Down Expand Up @@ -210,6 +215,8 @@ Performance improvements:
[VisionHeatmaps]: https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/
[TextHeatmaps]: https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/

[#167]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/167
[#166]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/166
[#162]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/162
[#159]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/159
[#157]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/157
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
name = "ExplainableAI"
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
authors = ["Adrian Hill <gh@adrianhill.de>"]
version = "0.8.1"
version = "0.9.0-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -12,6 +14,8 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
DifferentiationInterface = "0.5"
Distributions = "0.25"
Random = "<0.0.1, 1"
Reexport = "1"
Expand Down
5 changes: 5 additions & 0 deletions src/ExplainableAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import XAIBase: call_analyzer
using Base.Iterators
using Distributions: Distribution, Sampleable, Normal
using Random: AbstractRNG, GLOBAL_RNG

# Automatic differentiation
using ADTypes: AbstractADType, AutoZygote
using DifferentiationInterface: value_and_pullback
using Zygote
const DEFAULT_AD_BACKEND = AutoZygote()

include("compat.jl")
include("bibliography.jl")
Expand Down
13 changes: 11 additions & 2 deletions src/gradcam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,25 @@ GradCAM is compatible with a wide variety of CNN model-families.
# References
- $REF_SELVARAJU_GRADCAM
"""
struct GradCAM{F,A} <: AbstractXAIMethod
struct GradCAM{F,A,B<:AbstractADType} <: AbstractXAIMethod
feature_layers::F
adaptation_layers::A
backend::B

function GradCAM(
feature_layers::F, adaptation_layers::A, backend::B=DEFAULT_AD_BACKEND
) where {F,A,B<:AbstractADType}
new{F,A,B}(feature_layers, adaptation_layers, backend)
end
end
function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwargs...)
A = analyzer.feature_layers(input) # feature map
feature_map_size = size(A, 1) * size(A, 2)

# Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
grad, output, output_indices = gradient_wrt_input(
analyzer.adaptation_layers, A, ns, analyzer.backend
)
αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
return Explanation(Lᶜ, input, output, output_indices, :GradCAM, :cam, nothing)
Expand Down
57 changes: 42 additions & 15 deletions src/gradient.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,45 @@
function gradient_wrt_input(model, input, ns::AbstractOutputSelector)
output, back = Zygote.pullback(model, input)
output_indices = ns(output)

# Compute VJP w.r.t. full model output, selecting vector s.t. it masks output neurons
v = zero(output)
v[output_indices] .= 1
grad = only(back(v))
return grad, output, output_indices
function forward_with_output_selection(model, input, selector::AbstractOutputSelector)
output = model(input)
sel = selector(output)
return output[sel]
end

function gradient_wrt_input(
model, input, output_selector::AbstractOutputSelector, backend::AbstractADType
)
output = model(input)
return gradient_wrt_input(model, input, output, output_selector, backend)
end

function gradient_wrt_input(
model, input, output, output_selector::AbstractOutputSelector, backend::AbstractADType
)
output_selection = output_selector(output)
dy = zero(output)
dy[output_selection] .= 1

output, grad = value_and_pullback(model, backend, input, dy)
return grad, output, output_selection
end

"""
Gradient(model)
Analyze model by calculating the gradient of a neuron activation with respect to the input.
"""
struct Gradient{M} <: AbstractXAIMethod
struct Gradient{M,B<:AbstractADType} <: AbstractXAIMethod
model::M
Gradient(model) = new{typeof(model)}(model)
backend::B

function Gradient(model::M, backend::B=DEFAULT_AD_BACKEND) where {M,B<:AbstractADType}
new{M,B}(model, backend)
end
end

function call_analyzer(input, analyzer::Gradient, ns::AbstractOutputSelector; kwargs...)
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
grad, output, output_indices = gradient_wrt_input(
analyzer.model, input, ns, analyzer.backend
)
return Explanation(
grad, input, output, output_indices, :Gradient, :sensitivity, nothing
)
Expand All @@ -32,15 +51,23 @@ end
Analyze model by calculating the gradient of a neuron activation with respect to the input.
This gradient is then multiplied element-wise with the input.
"""
struct InputTimesGradient{M} <: AbstractXAIMethod
struct InputTimesGradient{M,B<:AbstractADType} <: AbstractXAIMethod
model::M
InputTimesGradient(model) = new{typeof(model)}(model)
backend::B

function InputTimesGradient(
model::M, backend::B=DEFAULT_AD_BACKEND
) where {M,B<:AbstractADType}
new{M,B}(model, backend)
end
end

function call_analyzer(
input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
)
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
grad, output, output_indices = gradient_wrt_input(
analyzer.model, input, ns, analyzer.backend
)
attr = input .* grad
return Explanation(
attr, input, output, output_indices, :InputTimesGradient, :attribution, nothing
Expand Down

0 comments on commit 0bb2b17

Please sign in to comment.