Skip to content

Commit

Permalink
Add full example to README
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Oct 11, 2024
1 parent 12046e5 commit c91aa45
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,31 @@ Let's explain why an image of a castle is classified as such by a vision model:

```julia
using ExplainableAI

# Load model and input
model = ... # load classifier model
input = ... # input in batch-dimension-last format
using VisionHeatmaps # visualization of explanations as heatmaps
using Zygote # load autodiff backend for gradient-based methods
using Flux, Metalhead # pre-trained vision models in Flux
using DataAugmentation # input preprocessing
using HTTP, FileIO, ImageIO # load image from URL
using ImageInTerminal # show heatmap in terminal

# Load & prepare model
model = VGG(16, pretrain=true)

# Load input
url = HTTP.URI("https://raw.githubusercontent.com/Julia-XAI/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg")
img = load(url)

# Preprocess input
mean = (0.485f0, 0.456f0, 0.406f0)
std = (0.229f0, 0.224f0, 0.225f0)
tfm = CenterResizeCrop((224, 224)) |> ImageToTensor() |> Normalize(mean, std)
input = apply(tfm, Image(img)) # apply DataAugmentation transform
input = reshape(input.data, 224, 224, 3, :) # unpack data and add batch dimension

# Run XAI method
analyzer = SmoothGrad(model)
expl = analyze(input, analyzer) # or: analyzer(input)

# Show heatmap
heatmap(expl)

# Or analyze & show heatmap directly
heatmap(input, analyzer)
expl = analyze(input, analyzer) # or: expl = analyzer(input)
heatmap(expl) # show heatmap using VisionHeatmaps.jl
```

By default, explanations are computed for the class with the highest activation.
Expand Down

0 comments on commit c91aa45

Please sign in to comment.