Skip to content

Commit

Permalink
Fix default noise level for NoiseAugmentation (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Dec 2, 2024
1 parent 61cf3f4 commit bd0e400
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ExplainableAI"
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
authors = ["Adrian Hill <gh@adrianhill.de>"]
version = "0.9.0"
version = "0.10.0-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ function call_analyzer(
end

"""
SmoothGrad(analyzer, [n=50, std=0.1, rng=GLOBAL_RNG])
SmoothGrad(analyzer, [n=50, distribution=Normal(0, σ²=0.01), rng=GLOBAL_RNG])
SmoothGrad(analyzer, [n=50, std=1.0f0, rng=GLOBAL_RNG])
SmoothGrad(analyzer, [n=50, distribution=Normal(0.0f0, 1.0f0), rng=GLOBAL_RNG])
Analyze model by calculating a smoothed sensitivity map.
This is done by averaging sensitivity maps of a `Gradient` analyzer over random samples
Expand Down
25 changes: 17 additions & 8 deletions src/input_augmentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ end
"""
augment_indices(indices, n)
Strip batch indices and return inidices for batch augmented by n samples.
Strip batch indices and return indices for batch augmented by n samples.
## Example
```julia-repl
Expand All @@ -83,11 +83,20 @@ function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
end

"""
NoiseAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG])
NoiseAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG])
NoiseAugmentation(analyzer, n)
NoiseAugmentation(analyzer, n, std::Real)
NoiseAugmentation(analyzer, n, distribution::Sampleable)
A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`.
A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from a scalar `distribution`.
This input augmentation is then averaged to return an `Explanation`.
Defaults to the normal distribution `Normal(0, std^2)` with `std=1.0f0`.
For optimal results, $REF_SMILKOV_SMOOTHGRAD recommends setting `std` between 10% and 20% of the input range of each sample,
e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
## Keyword arguments
- `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
Defaults to `GLOBAL_RNG`.
"""
struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
AbstractXAIMethod
Expand All @@ -96,11 +105,11 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
distribution::D
rng::R
end
function NoiseAugmentation(analyzer, n, distr::Sampleable, rng=GLOBAL_RNG)
return NoiseAugmentation(analyzer, n, distr::Sampleable, rng)
function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG)
return NoiseAugmentation(analyzer, n, distribution::Sampleable, rng)
end
function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...)
return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
function NoiseAugmentation(analyzer, n, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
return NoiseAugmentation(analyzer, n, Normal(zero(T), std^2), rng)
end

function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
Expand Down

0 comments on commit bd0e400

Please sign in to comment.