Skip to content

Commit

Permalink
Update benchmarks and test them with PkgJogger (#175)
Browse files (browse the repository at this point in the history)
  • Loading branch information
adrhill authored Oct 10, 2024
1 parent cf65291 commit 12046e5
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 96 deletions.
3 changes: 1 addition & 2 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ExplainableAI = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"

[compat]
BenchmarkTools = "1"
Expand Down
47 changes: 47 additions & 0 deletions benchmark/bench_jogger.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using BenchmarkTools
using Flux
using ExplainableAI

# Detect CI runs; currently informational only — TODO confirm downstream tooling reads it.
on_CI = haskey(ENV, "GITHUB_ACTIONS")

# Representative workload: a single 32×32 RGB image in Flux's WHCN layout.
T = Float32
input_size = (32, 32, 3, 1)
input = rand(T, input_size)

# Small VGG-style CNN used as the benchmark model.
model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1),
        Conv((3, 3), 8 => 8, relu; pad=1),
        MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1),
        Conv((3, 3), 16 => 16, relu; pad=1),
        MaxPool((2, 2)),
    ),
    Chain(
        Flux.flatten,
        Dense(1024 => 512, relu), # 524_800 parameters (was mislabeled 102_764_544)
        Dropout(0.5),
        Dense(512 => 100, relu),
    ),
)
Flux.testmode!(model, true) # disable Dropout so timings are deterministic

# Use one representative algorithm of each type
METHODS = Dict(
    "Gradient" => Gradient,
    "InputTimesGradient" => InputTimesGradient,
    "SmoothGrad" => model -> SmoothGrad(model, 5),
    "IntegratedGradients" => model -> IntegratedGradients(model, 5),
)

# Indirection so the analyzer construction itself is what @benchmarkable measures.
construct(method, model) = method(model)

suite = BenchmarkGroup()
suite["CNN"] = BenchmarkGroup(collect(keys(METHODS)))
for (name, method) in METHODS
    analyzer = method(model)
    # Tags now match the keys assigned below; the tag list previously said
    # "construct analyzer" while the key actually used was "constructor".
    suite["CNN"][name] = BenchmarkGroup(["constructor", "analyze"])
    suite["CNN"][name]["constructor"] = @benchmarkable construct($(method), $(model))
    suite["CNN"][name]["analyze"] = @benchmarkable analyze($(input), $(analyzer))
end
98 changes: 4 additions & 94 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,95 +1,5 @@
using BenchmarkTools
using LoopVectorization
using Tullio
using Flux
using PkgJogger
using ExplainableAI
using ExplainableAI: lrp!, modify_layer

on_CI = haskey(ENV, "GITHUB_ACTIONS")

T = Float32
input_size = (32, 32, 3, 1)
input = rand(T, input_size)

model = Chain(
Chain(
Conv((3, 3), 3 => 8, relu; pad=1),
Conv((3, 3), 8 => 8, relu; pad=1),
MaxPool((2, 2)),
Conv((3, 3), 8 => 16, relu; pad=1),
Conv((3, 3), 16 => 16, relu; pad=1),
MaxPool((2, 2)),
),
Chain(
Flux.flatten,
Dense(1024 => 512, relu), # 102_764_544 parameters
Dropout(0.5),
Dense(512 => 100, relu),
),
)
Flux.testmode!(model, true)

# Use one representative algorithm of each type
algs = Dict(
"Gradient" => Gradient,
"InputTimesGradient" => InputTimesGradient,
"LRP" => LRP,
"LREpsilonPlusFlat" => model -> LRP(model, EpsilonPlusFlat()),
"SmoothGrad" => model -> SmoothGrad(model, 5),
"IntegratedGradients" => model -> IntegratedGradients(model, 5),
)

# Define benchmark
_alg(alg, model) = alg(model) # for use with @benchmarkable macro

SUITE = BenchmarkGroup()
SUITE["CNN"] = BenchmarkGroup([k for k in keys(algs)])
for (name, alg) in algs
analyzer = alg(model)
SUITE["CNN"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
SUITE["CNN"][name]["construct analyzer"] = @benchmarkable _alg($(alg), $(model))
SUITE["CNN"][name]["analyze"] = @benchmarkable analyze($(input), $(analyzer))
end

# generate input for conv layers
insize = (32, 32, 3, 1)
in_dense = 64
out_dense = 10
aᵏ = rand(T, insize)

layers = Dict(
"Conv" => (Conv((3, 3), 3 => 2), aᵏ),
"Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
)
rules = Dict(
"ZeroRule" => ZeroRule(),
"EpsilonRule" => EpsilonRule(),
"GammaRule" => GammaRule(),
"WSquareRule" => WSquareRule(),
"FlatRule" => FlatRule(),
"AlphaBetaRule" => AlphaBetaRule(),
"ZPlusRule" => ZPlusRule(),
"ZBoxRule" => ZBoxRule(zero(T), oneunit(T)),
)

layernames = String.(keys(layers))
rulenames = String.(keys(rules))

SUITE["modify layer"] = BenchmarkGroup(rulenames)
SUITE["apply rule"] = BenchmarkGroup(rulenames)
for rname in rulenames
SUITE["modify layer"][rname] = BenchmarkGroup(layernames)
SUITE["apply rule"][rname] = BenchmarkGroup(layernames)
end

for (lname, (layer, aᵏ)) in layers
Rᵏ = similar(aᵏ)
Rᵏ⁺¹ = layer(aᵏ)
for (rname, rule) in rules
modified_layer = modify_layer(rule, layer)
SUITE["modify layer"][rname][lname] = @benchmarkable modify_layer($(rule), $(layer))
SUITE["apply rule"][rname][lname] = @benchmarkable lrp!(
$(Rᵏ), $(rule), $(layer), $(modified_layer), $(aᵏ), $(Rᵏ⁺¹)
)
end
end
# Use PkgJogger.@jog to create the JogExplainableAI module
# (aggregates the package's bench_*.jl files into one runnable suite).
@jog ExplainableAI
# `SUITE` is the top-level BenchmarkGroup name the benchmark tooling looks up
# in benchmarks.jl — presumably PkgBenchmark convention; TODO confirm.
SUITE = JogExplainableAI.suite()
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,8 @@ using JET
@info "Testing analyzers on batches..."
include("test_batches.jl")
end
@testset "Benchmark correctness" begin
    @info "Testing whether benchmarks are up-to-date..."
    # Executes each registered benchmark once via PkgJogger so that
    # benchmarks broken by API changes fail the test suite.
    include("test_benchmarks.jl")
end
end
1 change: 1 addition & 0 deletions test/test_batches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test

using Flux
using Random
using StableRNGs: StableRNG
using Distributions: Laplace

pseudorand(dims...) = rand(StableRNG(123), Float32, dims...)
Expand Down
4 changes: 4 additions & 0 deletions test/test_benchmarks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
using PkgJogger
using ExplainableAI

# Run every benchmark in the suite once (not timed) to verify that the
# benchmark code still executes against the current package API.
PkgJogger.@test_benchmarks ExplainableAI

0 comments on commit 12046e5

Please sign in to comment.