Skip to content

Commit

Permalink
check convergence
Browse files Browse the repository at this point in the history
  • Loading branch information
stecrotti committed Oct 11, 2023
1 parent 5f73938 commit b3435c8
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
16 changes: 12 additions & 4 deletions src/bp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ bethe_free_energy(f, bp::BP) = f(bp)
function iterate!(bp::BP; update_variable! = update_v_bp!, update_factor! = update_f_bp!,
maxiter=100, tol=1e-6, damp::Real=0.0, rein::Real=0.0,
f::AbstractVector{<:Real} = zeros(nvariables(bp.g)),
callback = (bp, ε, it) -> nothing,
extra_kwargs...
)
(; g, u, h) = bp
Expand All @@ -60,11 +61,14 @@ function iterate!(bp::BP; update_variable! = update_v_bp!, update_factor! = upda
for it in 1:maxiter
f .= 0
for i in variables(bp.g)
update_variable!(bp, i, hnew, damp, rein*it, f; extra_kwargs...)
errv[i] = update_variable!(bp, i, hnew, damp, rein*it, f; extra_kwargs...)
end
for a in factors(bp.g)
update_factor!(bp, a, unew, damp, f; extra_kwargs...)
errf[a] = update_factor!(bp, a, unew, damp, f; extra_kwargs...)
end
ε = max(maximum(errv), maximum(errf))
callback(bp, ε, it)
ε < tol && return it
end
return maxiter
end
Expand Down Expand Up @@ -94,16 +98,18 @@ function update_v_bp!(bp::BP, i::Integer, hnew, damp::Real, rein::Real,
msg_mult(m1, m2) = m1 .* m2
hnew[idx.(∂i)], b[i] = cavity(u[idx.(∂i)], msg_mult, ϕᵢ)
d = (degree(g, factor(a)) for a in neighbors(g, variable(i)))
err = -Inf
for ((_,_,id), dₐ) in zip(∂i, d)
zᵢ₂ₐ = sum(hnew[id])
f[i] -= log(zᵢ₂ₐ) * (1 - 1/dₐ)
hnew[id] ./= zᵢ₂ₐ
err = max(err, mean(abs, hnew[id] - h[id]))
h[id] = damp!(h[id], hnew[id], damp)
end
zᵢ = sum(b[i])
f[i] -= log(zᵢ) * (1 - degree(g, variable(i)) + sum(1/dₐ for dₐ in d; init=0.0))
b[i] ./= zᵢ
return nothing
return err
end

function update_f_bp!(bp::BP, a::Integer, unew, damp::Real, f=zeros(nvariables(bp.g));
Expand All @@ -121,13 +127,15 @@ function update_f_bp!(bp::BP, a::Integer, unew, damp::Real, f=zeros(nvariables(b
end
end
dₐ = degree(g, factor(a))
err = -Inf
for (i, _, id) in ∂a
zₐ₂ᵢ = sum(unew[id])
f[i] -= log(zₐ₂ᵢ) / dₐ
unew[id] ./= zₐ₂ᵢ
err = max(err, mean(abs, unew[id] - u[id]))
u[id] = damp!(u[id], unew[id], damp)
end
return nothing
return err
end

beliefs_bp(bp::BP) = bp.b
Expand Down
8 changes: 6 additions & 2 deletions src/maxsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ function update_v_ms!(bp::BP, i::Integer, hnew, damp::Real, rein::Real,
msg_sum(m1, m2) = m1 .+ m2
hnew[idx.(∂i)], b[i] = cavity(u[idx.(∂i)], msg_sum, logϕᵢ)
d = (degree(g, factor(a)) for a in neighbors(g, variable(i)))
err = -Inf
for ((_,_,id), dₐ) in zip(∂i, d)
fᵢ₂ₐ = maximum(hnew[id])
f[i] -= fᵢ₂ₐ * (1 - 1/dₐ)
hnew[id] .-= fᵢ₂ₐ
err = max(err, mean(abs, hnew[id] - h[id]))
h[id] = damp!(h[id], hnew[id], damp)
end
fᵢ = maximum(b[i])
f[i] -= fᵢ * (1 - degree(g, variable(i)) + sum(1/dₐ for dₐ in d; init=0.0))
b[i] .-= fᵢ
return nothing
return err
end

function update_f_ms!(bp::BP, a::Integer, unew, damp::Real, f=zeros(nvariables(bp.g));
Expand All @@ -38,13 +40,15 @@ function update_f_ms!(bp::BP, a::Integer, unew, damp::Real, f=zeros(nvariables(b
end
end
dₐ = degree(g, factor(a))
err = -Inf
for (i, _, id) in ∂a
fₐ₂ᵢ = maximum(unew[id])
f[i] -= fₐ₂ᵢ / dₐ
unew[id] .-= fₐ₂ᵢ
err = max(err, mean(abs, unew[id] - u[id]))
u[id] = damp!(u[id], unew[id], damp)
end
return nothing
return err
end

beliefs_ms(bp) = bp.b
Expand Down
6 changes: 3 additions & 3 deletions test/Models/ising.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
g = IndexedGraph(path_graph(2))
ising = Ising(g, J, h, β)
bp = BP(ising)
iterate!(bp, maxiter=2)
iterate!(bp, maxiter=2, tol=0)
b = beliefs(bp)
m = reduce.(-, b)
pb = only(factor_beliefs(bp))
Expand All @@ -32,7 +32,7 @@ end
ising = Ising(g, J, h, β)
bp = BP(ising)
f = zeros(N)
iterate!(bp; maxiter=100, f)
iterate!(bp; maxiter=20, f, tol=0)
b = beliefs(bp)
b_ex = exact_marginals(bp)
@test b b_ex
Expand All @@ -58,7 +58,7 @@ end
ising = Ising(g, J, h, β)
bp = BP(ising)
f = zeros(N)
iterate_ms!(bp; maxiter=100, f)
iterate_ms!(bp; maxiter=20, f, tol=0)
e = avg_energy(avg_energy_ms, bp)
e_ex = exact_minimum_energy(bp)
@test e e_ex
Expand Down
8 changes: 4 additions & 4 deletions test/bp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ rng = MersenneTwister(0)
qs = rand(rng, 2:2, nvar)
bp = rand_bp(rng, g, qs)
f = zeros(nvar)
iterate!(bp; maxiter=100, f, damp=0.2)
iterate!(bp; maxiter=50, f, damp=0.2, tol=0)
b = beliefs(bp)
@test b exact_marginals(bp)
bf = factor_beliefs(bp)
Expand All @@ -22,7 +22,7 @@ rng = MersenneTwister(0)

bp = rand_bp(rng, g, qs)
f = zeros(nvar)
iterate_ms!(bp; maxiter=10, f)
iterate_ms!(bp; maxiter=10, f, tol=0)
bfe = bethe_free_energy_ms(bp)
@test bfe sum(f)
@test bfe exact_minimum_energy(bp)
Expand All @@ -36,8 +36,8 @@ end
nfact, nvar = size(adjacency_matrix(g))
qs = rand(rng, 2:2, nvar)
bp = rand_bp(rng, g, qs)
iterate!(bp; maxiter=100, rein=0)
iterate!(bp; maxiter=100, rein=10)
iterate!(bp; maxiter=50, rein=0, tol=0)
iterate!(bp; maxiter=50, rein=10, tol=0)
b_bp = beliefs(bp)

refresh!(bp)
Expand Down

0 comments on commit b3435c8

Please sign in to comment.