Allow Gradient analyzers on non-Flux models (#150)
* Allow Gradient analyzers on non-Flux models

* Fix typo in `BATCHDIM_MISSING` error
adrhill authored Oct 17, 2023
1 parent bfaf500 commit 5dd3dfd
Showing 3 changed files with 17 additions and 15 deletions.
src/analyze_api.jl (12 changes: 5 additions & 7 deletions)
@@ -4,7 +4,7 @@ abstract type AbstractXAIMethod end
 
 const BATCHDIM_MISSING = ArgumentError(
     """The input is a 1D vector and therefore missing the required batch dimension.
-    Call `analyze` with the keyword argument `add_batch_dim=false`."""
+    Call `analyze` with the keyword argument `add_batch_dim=true`."""
 )

"""
@@ -46,16 +46,14 @@ end
 
 # lower-level call to method
 function _analyze(
-    input::AbstractArray{T,N},
+    input::AbstractArray,
     method::AbstractXAIMethod,
     sel::AbstractNeuronSelector;
     add_batch_dim::Bool=false,
     kwargs...,
-) where {T<:Real,N}
-    if add_batch_dim
-        return method(batch_dim_view(input), sel; kwargs...)
-    end
-    N < 2 && throw(BATCHDIM_MISSING)
+)
+    add_batch_dim && (input = batch_dim_view(input))
+    ndims(input) < 2 && throw(BATCHDIM_MISSING)
     return method(input, sel; kwargs...)
 end
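The rewrite drops the `where {T<:Real,N}` parameters: the input is first promoted via `batch_dim_view`, and `ndims` is then checked on the possibly promoted value, so 1D vectors pass when `add_batch_dim=true` is set, exactly as the corrected error message advises. A minimal sketch of the user-facing effect, assuming this is ExplainableAI.jl's exported `analyze`, which forwards this keyword (model and sizes are illustrative):

```julia
using Flux, ExplainableAI

model = Chain(Dense(10 => 5, relu), Dense(5 => 2))
x = rand(Float32, 10)                    # 1D input, batch dimension missing

# analyze(x, Gradient(model))            # would throw BATCHDIM_MISSING
expl = analyze(x, Gradient(model); add_batch_dim=true)  # input viewed as 10×1
```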

src/gradient.jl (14 changes: 9 additions & 5 deletions)
@@ -14,10 +14,12 @@ end
 Analyze model by calculating the gradient of a neuron activation with respect to the input.
 """
-struct Gradient{C<:Chain} <: AbstractXAIMethod
-    model::C
+struct Gradient{M} <: AbstractXAIMethod
+    model::M
+    Gradient(model) = new{typeof(model)}(model)
     Gradient(model::Chain) = new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
 end
 
 function (analyzer::Gradient)(input, ns::AbstractNeuronSelector)
     grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
     return Explanation(grad, output, output_indices, :Gradient, nothing)
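Loosening the struct parameter from `C<:Chain` to an unconstrained `M` is what enables non-Flux models: the new one-argument constructor wraps any callable as-is, while `Chain`s still get the `check_output_softmax` and `testmode!` treatment. A sketch under that reading, where the plain function below is a hypothetical stand-in for a non-Flux model:

```julia
using ExplainableAI

f(x) = x .^ 2            # hypothetical non-Flux "model": any differentiable callable
analyzer = Gradient(f)   # dispatches to the new untyped constructor, stores f as-is
```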
@@ -29,12 +31,14 @@ end
 Analyze model by calculating the gradient of a neuron activation with respect to the input.
 This gradient is then multiplied element-wise with the input.
 """
-struct InputTimesGradient{C<:Chain} <: AbstractXAIMethod
-    model::C
+struct InputTimesGradient{M} <: AbstractXAIMethod
+    model::M
+    InputTimesGradient(model) = new{typeof(model)}(model)
     function InputTimesGradient(model::Chain)
-        return new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
+        new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
     end
 end
 
 function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
     grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
     attr = input .* grad
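`InputTimesGradient` gets the same two-constructor treatment, and its analyzer body differs from `Gradient` only in the trailing `attr = input .* grad`. A quick consistency check one might run, assuming `Explanation` exposes the attribution in a `val` field (an assumption about this package's struct layout; sizes are illustrative):

```julia
using Flux, ExplainableAI

model = Chain(Dense(10 => 2))
x = rand(Float32, 10, 1)             # one sample, batch dimension included

g  = analyze(x, Gradient(model))
ig = analyze(x, InputTimesGradient(model))
@assert ig.val ≈ x .* g.val          # element-wise input-times-gradient
```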
src/heatmap.jl (6 changes: 3 additions & 3 deletions)
@@ -42,15 +42,15 @@ See also [`analyze`](@ref).
     instead of computing it individually for each sample. Defaults to `false`.
 """
 function heatmap(
-    val::AbstractArray{T,N};
+    val::AbstractArray;
     cs::ColorScheme=ColorSchemes.seismic,
     reduce::Symbol=:sum,
     rangescale::Symbol=:centered,
     permute::Bool=true,
     unpack_singleton::Bool=true,
     process_batch::Bool=false,
-) where {T,N}
-    N != 4 && throw(
+)
+    ndims(val) != 4 && throw(
         ArgumentError(
             "heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
             Please reshape your explanation to match this format if your model doesn't adhere to this convention.",
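Like `_analyze`, `heatmap` now reads the rank off the value with `ndims(val)` instead of a type parameter; the WHCN requirement itself is unchanged. A sketch of satisfying it by reshaping a flat attribution, as the error message suggests (sizes are illustrative):

```julia
using ExplainableAI

attr = rand(Float32, 28 * 28)                 # flat attribution for one image
# heatmap(attr)                               # throws the ArgumentError above
img = heatmap(reshape(attr, 28, 28, 1, 1))    # width × height × channels × batch
```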