Skip to content

Commit

Permalink
Speedup breadth-first search by one-hot encoding indices
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 20, 2024
1 parent 45429f8 commit 33b1627
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 47 deletions.
6 changes: 6 additions & 0 deletions src/Counters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ flops(sexpr::SizedEinExpr) =

flops(expr::EinExpr, size) = flops(SizedEinExpr(expr, size))

"""
    flops(_out, _suminds, sizelist)

Return the product of the entries of `sizelist` whose 1-based position `i` is reported
by `onehot_in` as a member of `_out` or of `_suminds`. Positions that belong to neither
set contribute a factor of `1`.

Presumably `_out` and `_suminds` are one-hot encoded index sets (an integer bitmask or a
`BitSet`) and `sizelist[i]` is the dimension of the `i`-th index -- confirm against callers.

An empty `sizelist` yields `1` (the empty product) instead of raising an error.
"""
function flops(_out, _suminds, sizelist)
    # `init = 1` is the multiplicative identity: it leaves non-empty products unchanged
    # and makes the reduction well-defined when there are no indices at all.
    return mapreduce(*, enumerate(sizelist); init = 1) do (i, dim)
        onehot_in(i, _out) || onehot_in(i, _suminds) ? dim : 1
    end
end

function fastflops(sexpr::SizedEinExpr)
if nargs(sexpr) == 0 || nargs(sexpr) == 1 && isempty(suminds(sexpr))
return 0
Expand Down
2 changes: 2 additions & 0 deletions src/EinExprs.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module EinExprs

include("Utils.jl")

include("EinExpr.jl")
export EinExpr
export head, args, inds, hyperinds, suminds, parsuminds, collapse!, contractorder, select, neighbours
Expand Down
74 changes: 27 additions & 47 deletions src/Optimizers/Exhaustive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,44 +83,6 @@ function exhaustive_depthfirst(
end
end

# Build an empty one-hot set of the requested representation: the all-zero bit pattern
# for an integer type, or a fresh (empty) `BitSet`.
onehot_init(::Type{T}) where {T<:Integer} = zero(T)
onehot_init(::Type{BitSet}) = BitSet()

# Membership test for a one-hot set.
# Integer sets: element `i` is stored in bit `i - 1`; positions beyond the bit width of
# the integer type can never be present, so they report `false`.
function onehot_in(i, set::T) where {T<:Integer}
    nbits = sizeof(T) * 8
    if i > nbits
        return false
    end
    bit = one(T) << (i - 1)
    return !iszero(bit & set)
end
# `BitSet` sets: defer to the standard membership test.
onehot_in(i, set::BitSet) = i in set

# Add element `i` to a one-hot set.
# Integer sets are immutable values, so the updated bitmask is returned and the caller
# must keep the result. Raises an error when `i` does not fit in the integer's bit width.
function onehot_push!(set::T, i) where {T<:Integer}
    if i > sizeof(T) * 8
        error("Index out of bounds")
    end
    return set | (one(T) << (i - 1))
end
# `BitSet` sets are mutated in place via `push!`.
onehot_push!(set::BitSet, i) = push!(set, i)

# Remove element `i` from a one-hot set.
# Integer sets: returns the bitmask with bit `i - 1` cleared (the input value itself is
# not modified). Raises an error when `i` does not fit in the integer's bit width.
function onehot_pop!(set::T, i) where {T<:Integer}
    if i > sizeof(T) * 8
        error("Index out of bounds")
    end
    bit = one(T) << (i - 1)
    return set & ~bit
end
# `BitSet` sets: forwards to `pop!`, which mutates `set`. Note that the value returned is
# whatever `pop!` returns, not the updated set as in the integer method.
onehot_pop!(set::BitSet, i) = pop!(set, i)

# Report whether two one-hot sets share no element.
# Integer sets: disjoint exactly when the bitwise AND of both masks is zero.
function onehot_isdisjoint(a::T, b::T) where {T<:Integer}
    return iszero(a & b)
end
function onehot_isdisjoint(a::BitSet, b::BitSet)
    return isdisjoint(a, b)
end

# Union of two one-hot sets.
# Integer sets: bitwise OR of both masks. `BitSet` sets: a new set from `union`.
function onehot_union(a::T, b::T) where {T<:Integer}
    return a | b
end
function onehot_union(a::BitSet, b::BitSet)
    return union(a, b)
end

# Return the single element held by a one-hot set.
# Integer sets: exactly one bit must be set; its 1-based position is returned, otherwise
# an error is raised.
function onehot_only(set::T) where {T<:Integer}
    count_ones(set) == 1 || error("Expected 1 element")
    return trailing_zeros(set) + 1
end
# `BitSet` sets: defer to `only`.
onehot_only(set::BitSet) = only(set)

# Report whether a one-hot set holds no element.
# Integer sets: empty exactly when no bit is set.
function onehot_isempty(set::T) where {T<:Integer}
    return iszero(set)
end
function onehot_isempty(set::BitSet)
    return isempty(set)
end

function exhaustive_breadthfirst(
@specialize(metric::Val{Metric}),
expr::SizedEinExpr{L},
Expand Down Expand Up @@ -150,8 +112,21 @@ function exhaustive_breadthfirst(
# NOTE no cost because no contraction on S₁ (only input tensors)
costs = Dict{SetType,BigInt}(s => zero(BigInt) for s in S[1])

index_enc = Dict{L,SetType}(index => onehot_push!(onehot_init(SetType), i) for (i, index) in enumerate(inds(expr)))
index_dec = inds(expr)
size_enc = map(Base.Fix1(size, expr), index_dec)
skip = reduce(Iterators.map(x -> index_enc[x], expr.head); init = onehot_init(SetType)) do acc, index
index = trailing_zeros(index) + 1
onehot_push!(acc, index)
end

# contains the indices of the intermediate tensors in S
indices = Dict{SetType,Vector{L}}(s => head(expr.args[onehot_only(s)]) for s in S[1])
indices =
Iterators.map(S[1]) do s
s => reduce(head(expr.args[onehot_only(s)]); init = onehot_init(SetType)) do acc, index
onehot_union(acc, index_enc[index])
end
end |> Dict{SetType,SetType}

# contains the best-known contraction tree for constructing each object in S[c]
trees = Dict{SetType,Tuple{SetType,SetType}}(s => (onehot_init(SetType), onehot_init(SetType)) for s in S[1])
Expand All @@ -171,24 +146,25 @@ function exhaustive_breadthfirst(
onehot_isdisjoint(ta, tb) || continue

# outer products do not generally improve contraction path
!outer && isdisjoint(indices[ta], indices[tb]) && continue
!outer && onehot_isdisjoint(indices[ta], indices[tb]) && continue

# new candidate contraction
tc = onehot_union(ta, tb) # aka Q in the paper
get(costs, tc, cost_cur) > cost_prev || continue

# compute cost of getting `tc` by contracting `ta` and `tb`
shallow_expr_a = EinExpr(indices[ta])
shallow_expr_b = EinExpr(indices[tb])
expr_c = sum(shallow_expr_a, shallow_expr_b; skip = expr.head)
inds_a = indices[ta]
inds_b = indices[tb]
contracting_inds = onehot_setdiff(onehot_intersect(inds_a, inds_b), skip)
inds_c = onehot_setdiff(onehot_union(inds_a, inds_b), contracting_inds)

μ = costs[ta] + costs[tb] + Metric(SizedEinExpr(expr_c, expr.size))
μ = costs[ta] + costs[tb] + Metric(inds_c, contracting_inds, size_enc)

# if `μ` is the cheapest known cost for constructing `tc`, record it
if μ <= get(costs, tc, cost_cur)
tc ∉ S[c] && push!(S[c], tc)
costs[tc] = μ
indices[tc] = head(expr_c)
indices[tc] = inds_c
trees[tc] = (ta, tb)

elseif cost_cur < μ < cost_next
Expand All @@ -205,11 +181,15 @@ function exhaustive_breadthfirst(
function recurse_construct(tc)
ta, tb = trees[tc]

inds_c = map(last, Iterators.filter(enumerate(index_dec)) do (i, _)
onehot_in(i, indices[tc])
end)

if onehot_isempty(ta) && onehot_isempty(tb)
return EinExpr(indices[tc]::Vector{L})
return EinExpr(inds_c::Vector{L})
end

return EinExpr(indices[tc], map(recurse_construct, [ta, tb]))
return EinExpr(inds_c, map(recurse_construct, [ta, tb]))
end

path = recurse_construct(only(S[n]))
Expand Down
46 changes: 46 additions & 0 deletions src/Utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
    onehot_init(T)

Return an empty one-hot set of representation `T`: `zero(T)` when `T` is an integer type
(no bit set), or a new empty `BitSet` when `T` is `BitSet`.
"""
onehot_init(::Type{T}) where {T<:Integer} = zero(T)
onehot_init(::Type{BitSet}) = BitSet()

"""
    onehot_in(i, set)

Return whether element `i` is a member of the one-hot set `set`.

For an integer `set`, element `i` corresponds to bit `i - 1`; any `i` larger than the
bit width of the integer type is reported as absent. For a `BitSet`, this is `in(i, set)`.
"""
function onehot_in(i, set::T) where {T<:Integer}
    width = sizeof(T) * 8
    if i > width
        return false
    end
    probe = one(T) << (i - 1)
    return !iszero(probe & set)
end
onehot_in(i, set::BitSet) = i in set

"""
    onehot_push!(set, i)

Return `set` with element `i` added.

For an integer `set` the input is a plain value, so nothing is mutated: the new bitmask
(bit `i - 1` switched on) is returned and must be kept by the caller. An error is raised
when `i` exceeds the bit width of the integer type. For a `BitSet`, `set` is updated in
place through `push!`.
"""
function onehot_push!(set::T, i) where {T<:Integer}
    if i > sizeof(T) * 8
        error("Index out of bounds")
    end
    return set | (one(T) << (i - 1))
end
onehot_push!(set::BitSet, i) = push!(set, i)

"""
    onehot_pop!(set, i)

Remove element `i` from the one-hot set `set`.

For an integer `set`, return the bitmask with bit `i - 1` cleared; the input value is not
modified. An error is raised when `i` exceeds the bit width of the integer type. For a
`BitSet`, the call is forwarded to `pop!`, which mutates `set`; the return value is then
whatever `pop!` returns rather than the updated set.
"""
function onehot_pop!(set::T, i) where {T<:Integer}
    if i > sizeof(T) * 8
        error("Index out of bounds")
    end
    cleared = ~(one(T) << (i - 1))
    return set & cleared
end
onehot_pop!(set::BitSet, i) = pop!(set, i)

"""
    onehot_isdisjoint(a, b)

Return `true` when the one-hot sets `a` and `b` have no element in common. Integer sets
are disjoint when their bitwise AND is zero; `BitSet`s use `isdisjoint`.
"""
function onehot_isdisjoint(a::T, b::T) where {T<:Integer}
    return iszero(a & b)
end
function onehot_isdisjoint(a::BitSet, b::BitSet)
    return isdisjoint(a, b)
end

"""
    onehot_union(a, b)

Return the union of the one-hot sets `a` and `b`: the bitwise OR for integer sets, or the
result of `union` for `BitSet`s.
"""
function onehot_union(a::T, b::T) where {T<:Integer}
    return a | b
end
function onehot_union(a::BitSet, b::BitSet)
    return union(a, b)
end

"""
    onehot_intersect(a, b)

Return the elements common to the one-hot sets `a` and `b`: the bitwise AND for integer
sets, or the result of `intersect` for `BitSet`s.
"""
function onehot_intersect(a::T, b::T) where {T<:Integer}
    return a & b
end
function onehot_intersect(a::BitSet, b::BitSet)
    return intersect(a, b)
end

"""
    onehot_setdiff(a, b)

Return the elements of the one-hot set `a` that are not in `b`. For integer sets this
keeps the bits of `a` that are unset in `b`; `BitSet`s use `setdiff`.
"""
function onehot_setdiff(a::T, b::T) where {T<:Integer}
    not_in_b = ~b
    return a & not_in_b
end
function onehot_setdiff(a::BitSet, b::BitSet)
    return setdiff(a, b)
end

"""
    onehot_symdiff(a, b)

Return the symmetric difference of the one-hot sets `a` and `b`, i.e. the elements that
belong to exactly one of them. Integer sets use the bitwise exclusive OR; `BitSet`s use
`symdiff`.
"""
function onehot_symdiff(a::T, b::T) where {T<:Integer}
    # `xor` is the function form of the infix exclusive-or operator.
    return xor(a, b)
end
function onehot_symdiff(a::BitSet, b::BitSet)
    return symdiff(a, b)
end

"""
    onehot_only(set)

Return the single element stored in the one-hot set `set`.

For an integer `set`, exactly one bit must be set and its 1-based position is returned;
otherwise an error is raised. For a `BitSet`, this is `only(set)`.
"""
function onehot_only(set::T) where {T<:Integer}
    if count_ones(set) != 1
        error("Expected 1 element")
    end
    return trailing_zeros(set) + 1
end
onehot_only(set::BitSet) = only(set)

"""
    onehot_isempty(set)

Return `true` when the one-hot set `set` holds no element: an integer set with no bit
set, or an empty `BitSet`.
"""
function onehot_isempty(set::T) where {T<:Integer}
    return iszero(set)
end
function onehot_isempty(set::BitSet)
    return isempty(set)
end

0 comments on commit 33b1627

Please sign in to comment.