diff --git a/benchmark/make.jl b/benchmark/make.jl index d7002fc..d9ffa8e 100644 --- a/benchmark/make.jl +++ b/benchmark/make.jl @@ -15,18 +15,15 @@ suite["greedy"] = BenchmarkGroup([]) suite["kahypar"] = BenchmarkGroup([]) # BENCHMARK 1 -expr = EinExpr( - Symbol[], - [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), - ], -) +expr = sum([ + EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), + EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), + EinExpr([:j], Dict(i => 2 for i in [:j])), + EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), + EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), + EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), + EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), +]) suite["naive"][1] = @benchmarkable einexpr(EinExprs.Naive(), $expr) suite["exhaustive"][1] = @benchmarkable einexpr(Exhaustive(), $expr) @@ -41,7 +38,7 @@ D = EinExpr([:c, :h, :d, :i], Dict(:c => 2, :h => 2, :d => 2, :i => 2)) E = EinExpr([:f, :i, :g, :j], Dict(:f => 2, :i => 2, :g => 2, :j => 2)) F = EinExpr([:B, :h, :k, :l], Dict(:B => 2, :h => 2, :k => 2, :l => 2)) G = EinExpr([:j, :k, :l, :D], Dict(:j => 2, :k => 2, :l => 2, :D => 2)) -expr = EinExpr([:A, :B, :C, :D], [A, B, C, D, E, F, G]) +expr = sum([A, B, C, D, E, F, G], skip = [:A, :B, :C, :D]) suite["naive"][2] = @benchmarkable einexpr(EinExprs.Naive(), $expr) suite["exhaustive"][2] = @benchmarkable einexpr(Exhaustive(), $expr) diff --git a/ext/EinExprsMakieExt.jl b/ext/EinExprsMakieExt.jl index 7d3a36c..79e52ed 100644 --- a/ext/EinExprsMakieExt.jl +++ b/ext/EinExprsMakieExt.jl @@ -15,13 +15,13 @@ const MAX_EDGE_WIDTH = 10.0 const MAX_ARROW_SIZE = 35.0 const MAX_NODE_SIZE = 40.0 -function Makie.plot(path::EinExpr; kwargs...) +function Makie.plot(path::SizedEinExpr; kwargs...) f = Figure() ax, p = plot!(f[1, 1], path; kwargs...) return Makie.FigureAxisPlot(f, ax, p) end -function Makie.plot!(f::Union{Figure,GridPosition}, path::EinExpr; kwargs...) +function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...) ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3 Axis3(f[1, 1]) else @@ -65,13 +65,13 @@ end # TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap function Makie.plot!( ax::Union{Axis,Axis3}, - path::EinExpr; + path::SizedEinExpr; colormap = to_colormap(:viridis)[begin:end-10], inds = false, kwargs..., ) - handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path))) - graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path) for from in to.args]) + handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path))) + graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args]) lin_size = length.(PostOrderDFS(path))[1:end-1] lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path))) diff --git a/src/Counters.jl b/src/Counters.jl index 1ec72fd..eca2b8e 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -3,19 +3,30 @@ Count the number of mathematical operations will be performed by the contraction of the root of the `path` tree. """ -flops(expr::EinExpr) = - if length(expr.args) == 0 || length(expr.args) == 1 && isempty(suminds(expr)) +flops(sexpr::SizedEinExpr) = + if nargs(sexpr) == 0 || nargs(sexpr) == 1 && isempty(suminds(sexpr)) 0 else - mapreduce(Base.Fix1(size, expr), *, Iterators.flatten((head(expr), suminds(expr))), init = one(BigInt)) + mapreduce( + Base.Fix1(getindex, sexpr.size), + *, + Iterators.flatten((head(sexpr), suminds(sexpr))), + init = one(BigInt), + ) end +flops(expr::EinExpr, size) = flops(SizedEinExpr(expr, size)) + """ removedsize(path::EinExpr) Count the amount of memory that will be freed after performing the contraction of the root of the `path` tree. """ -removedsize(expr::EinExpr) = mapreduce(prod ∘ size, +, expr.args) - prod(size(expr)) +removedsize(sexpr::SizedEinExpr) = -length(sexpr) + mapreduce(+, sexpr.args) do arg + length(SizedEinExpr(arg, sexpr.size)) +end + +removedsize(expr::EinExpr, size) = removedsize(SizedEinExpr(expr, size)) """ removedrank(path::EinExpr) @@ -23,3 +34,10 @@ removedsize(expr::EinExpr) = mapreduce(prod ∘ size, +, expr.args) - prod(size( Count the rank reduction after performing the contraction of the root of the `path` tree. """ removedrank(expr::EinExpr) = mapreduce(ndims, max, expr.args) - ndims(expr) +removedrank(expr::EinExpr, _) = removedrank(expr) +removedrank(sexpr::SizedEinExpr, _) = removedrank(sexpr.path) + +for f in [:flops, :removedsize] + @eval $f(sizedict::Dict{Symbol}) = Base.Fix2($f, sizedict) +end +removedrank(::Dict) = removedrank diff --git a/src/EinExpr.jl b/src/EinExpr.jl index cc65304..e4976a4 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -2,32 +2,18 @@ using Base: AbstractVecOrTuple using DataStructures: DefaultDict using AbstractTrees -struct EinExpr +Base.@kwdef struct EinExpr head::Vector{Symbol} - args::Vector{EinExpr} - size::Dict{Symbol,Int} - - # TODO checks: same dim for index, valid indices - EinExpr(head, args) = new(head, args, Dict{Symbol,EinExpr}()) - - function EinExpr(head::AbstractVector{Symbol}, size::AbstractDict{Symbol,Int}) - head ⊆ keys(size) || throw(ArgumentError("Missing sizes for indices $(setdiff(head, keys(size)))")) - new(head, EinExpr[], size) - end + args::Vector{EinExpr} = EinExpr[] end +EinExpr(head) = EinExpr(head, EinExpr[]) +EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}) = EinExpr(head, map(EinExpr, args)) + EinExpr(head::NTuple, args) = EinExpr(collect(head), args) EinExpr(head, args::NTuple) = EinExpr(head, collect(args)) EinExpr(head::NTuple, args::NTuple) = EinExpr(collect(head), collect(args)) -function EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}, sizes) - args = map(args) do arg - sizedict = filter(∈(arg) ∘ first, sizes) - EinExpr(arg, sizedict) - end - EinExpr(head, args) -end - """ head(path::EinExpr) @@ -46,6 +32,8 @@ See also: [`head`](@ref). """ args(path::EinExpr) = path.args +nargs(path::EinExpr) = length(path.args) + """ inds(path) @@ -100,11 +88,8 @@ Base.ndims(path::EinExpr) = length(head(path)) Return the size of the resulting tensor from contracting `path`. If `index` is specified, return the size of such index. """ -Base.size(path::EinExpr) = (size(path, i) for i in head(path)) |> splat(tuple) -Base.size(path::EinExpr, i::Symbol) = - Iterators.filter(∋(i) ∘ head, Leaves(path)) |> first |> Base.Fix2(getproperty, :size) |> Base.Fix2(getindex, i) - -Base.length(path::EinExpr) = (prod ∘ size)(path) +Base.size(path::EinExpr, sizedict) = (sizedict[i] for i in head(path)) |> splat(tuple) +Base.length(path::EinExpr, sizedict) = (prod ∘ size)(path, sizedict) """ collapse!(path::EinExpr) @@ -241,6 +226,7 @@ Create an `EinExpr` from other `EinExpr`s. function Base.sum(args::Vector{EinExpr}; skip = Symbol[]) _head = Symbol[] _counts = Int[] + for arg in args for index in head(arg) i = findfirst(Base.Fix1(===, index), _head) @@ -248,17 +234,37 @@ function Base.sum(args::Vector{EinExpr}; skip = Symbol[]) push!(_head, index) push!(_counts, 1) else - _counts[i] += 1 + @inbounds _counts[i] += 1 end end end - _head = map(first, Iterators.filter(zip(_head, _counts)) do (index, count) - count == 1 || index ∈ skip - end) + # NOTE `map` with `Iterators.filter` induces many heap grows; allocating once and deleting is faster + for i in Iterators.reverse(eachindex(_head, _counts)) + (_counts[i] == 1 || _head[i] ∈ skip) && continue + deleteat!(_head, i) + end + EinExpr(_head, args) end +function Base.sum(a::EinExpr, b::EinExpr; skip = Symbol[]) + _head = copy(head(a)) + + for index in head(b) + i = findfirst(Base.Fix1(===, index), _head) + if isnothing(i) + push!(_head, index) + elseif index ∈ skip + continue + else + deleteat!(_head, i) + end + end + + EinExpr(_head, [a, b]) +end + function Base.string(path::EinExpr; recursive::Bool = false) !recursive && return "$(join(map(x -> string.(head(x)) |> join, args(path)), ","))->$(string.(head(path)) |> join)" map(string, Branches(path)) diff --git a/src/EinExprs.jl b/src/EinExprs.jl index 589af7b..c3d52f8 100644 --- a/src/EinExprs.jl +++ b/src/EinExprs.jl @@ -5,6 +5,9 @@ export EinExpr export head, args, inds, hyperinds, suminds, parsuminds, collapse!, contractorder, select, neighbours export Branches, branches, leaves +include("SizedEinExpr.jl") +export SizedEinExpr + include("Counters.jl") export flops, removedsize diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index ea3591c..ef99a19 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -23,35 +23,41 @@ The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and `` end function einexpr(config::Exhaustive, path; cost = BigInt(0)) - leader = Ref{NamedTuple{(:path, :cost),Tuple{EinExpr,BigInt}}}((; + # metric = Base.Fix2(config.metric, path.size) + leader = Ref((; path = einexpr(Naive(), path), cost = mapreduce(config.metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt, )) - cache = Dict{Vector{Symbol},BigInt}() - __einexpr_exhaustive_it(path, cost, config.metric, config.outer, leader, cache) + __einexpr_exhaustive_it(path, cost, Val(config.metric), config.outer, leader) return leader[].path end -function __einexpr_exhaustive_it(path, cost, metric, outer, leader, cache) - if length(path.args) == 1 - # remove identity einsum (i.e. "i...->i...") - path = path.args[1] - - leader[] = (; path, cost = mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))::BigInt) +function __einexpr_exhaustive_it( + path, + cost, + @specialize(metric::Val{Metric}), + outer, + leader; + cache = Dict{Vector{Symbol},BigInt}(), + hashyperinds = !isempty(hyperinds(path)), +) where {Metric} + if nargs(path) <= 2 + #= mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))) =# + leader[] = (; path = path, cost = cost) return end - for (i, j) in combinations(args(path), 2) + for (i, j) in combinations(path.args, 2) !outer && isdisjoint(head(i), head(j)) && continue - candidate = sum([i, j], skip = path.head ∪ hyperinds(path)) + candidate = sum(i, j; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) # prune paths based on metric new_cost = cost + get!(cache, head(candidate)) do - metric(candidate) + Metric(SizedEinExpr(candidate, path.size)) end new_cost >= leader[].cost && continue - new_path = EinExpr(head(path), [candidate, filter(∉([i, j]), args(path))...]) - __einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader, cache) + new_path = SizedEinExpr(EinExpr(head(path), [candidate, filter(∉([i, j]), path.args)...]), path.size) # sum([candidate, filter(∉([i, j]), args(path))...], skip = path.head) + __einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader; cache, hashyperinds) end end diff --git a/src/Optimizers/Greedy.jl b/src/Optimizers/Greedy.jl index 24a27bc..11c53a9 100644 --- a/src/Optimizers/Greedy.jl +++ b/src/Optimizers/Greedy.jl @@ -27,7 +27,9 @@ The implementation uses a binary heaptree to sort candidate pairwise tensor cont outer::Bool = false end -function einexpr(config::Greedy, path) +function einexpr(config::Greedy, path, sizedict) + metric = config.metric(sizedict) + # generate initial candidate contractions queue = MutableBinaryHeap{Tuple{Float64,EinExpr}}( Base.By(first, Base.Reverse), @@ -36,12 +38,12 @@ function einexpr(config::Greedy, path) ) do (a, b) # TODO don't consider outer products candidate = sum([a, b], skip = path.head ∪ hyperinds(path)) - weight = config.metric(candidate) + weight = metric(candidate) (weight, candidate) end, ) - while length(path.args) > 2 && length(queue) > 1 + while nargs(path) > 2 && length(queue) > 1 # choose winner _, winner = config.choose(queue) @@ -55,7 +57,7 @@ function einexpr(config::Greedy, path) for other in Iterators.filter(other -> config.outer || !isdisjoint(winner.head, other.head), path.args) # TODO don't consider outer products candidate = sum([winner, other], skip = path.head ∪ hyperinds(path)) - weight = config.metric(candidate) + weight = metric(candidate) push!(queue, (weight, candidate)) end @@ -65,3 +67,5 @@ function einexpr(config::Greedy, path) return path end + +einexpr(config::Greedy, path::SizedEinExpr) = SizedEinExpr(einexpr(config, path.path, path.size), path.size) diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index b3c040c..96dc4b4 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -6,7 +6,7 @@ using Suppressor @kwdef struct HyPar <: Optimizer parts::Int = 2 imbalance::Float32 = 0.03 - stop::Function = <=(2) ∘ length ∘ Base.Fix2(getfield, :args) + stop::Function = <=(2) ∘ length ∘ Base.Fix2(getproperty, :args) configuration::Union{Nothing,Symbol,String} = nothing edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 @@ -25,7 +25,7 @@ function EinExprs.einexpr(config::HyPar, path) # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds) - vertex_weights = map(config.vertex_scaler ∘ length, path.args) + vertex_weights = map(config.vertex_scaler ∘ length, args(path)) hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights) @@ -36,13 +36,13 @@ function EinExprs.einexpr(config::HyPar, path) configuration = config.configuration, ) - args = map(unique(partitions)) do partition + _args = map(unique(partitions)) do partition selection = partitions .== partition - count(selection) == 1 && return only(path.args[selection]) + count(selection) == 1 && return only(args(path)[selection]) - expr = sum(path.args[selection], skip = path.head) + expr = sum(args(path)[selection], skip = path.head) einexpr(config, expr) end - return EinExpr(path.head, args) + return sum(_args, skip = path.head) end diff --git a/src/Optimizers/Naive.jl b/src/Optimizers/Naive.jl index 88d9a5c..9b31d93 100644 --- a/src/Optimizers/Naive.jl +++ b/src/Optimizers/Naive.jl @@ -1,10 +1,14 @@ +using AbstractTrees + struct Naive <: Optimizer end +einexpr(::Naive, path, _) = einexpr(Naive(), path) + function einexpr(::Naive, path) - hist = Dict(i => count(∋(i) ∘ head, path.args) for i in hyperinds(path)) + hist = Dict(i => count(∋(i) ∘ head, args(path)) for i in hyperinds(path)) - foldl(path.args) do a, b - expr = sum([a, b], skip = path.head ∪ collect(keys(hist))) + foldl(args(path)) do a, b + expr = sum([a, b], skip = head(path) ∪ collect(keys(hist))) for i in Iterators.filter(∈(keys(hist)), ∩(head(a), head(b))) hist[i] -= 1 @@ -14,3 +18,5 @@ function einexpr(::Naive, path) return expr end end + +einexpr(::Naive, path::SizedEinExpr) = SizedEinExpr(einexpr(Naive(), path.path), path.size) diff --git a/src/Optimizers/Optimizers.jl b/src/Optimizers/Optimizers.jl index d4fb537..516b8ce 100644 --- a/src/Optimizers/Optimizers.jl +++ b/src/Optimizers/Optimizers.jl @@ -3,7 +3,7 @@ abstract type Optimizer end function einexpr end einexpr(T::Type{<:Optimizer}, args...; kwargs...) = einexpr(T(; kwargs...), args...) -einexpr(config::Optimizer, expr) = einexpr(config, expr) +einexpr(config::Optimizer, expr, sizedict) = einexpr(config, expr, sizedict) include("Naive.jl") include("Exhaustive.jl") diff --git a/src/SizedEinExpr.jl b/src/SizedEinExpr.jl new file mode 100644 index 0000000..c63e555 --- /dev/null +++ b/src/SizedEinExpr.jl @@ -0,0 +1,80 @@ +using AbstractTrees + +struct SizedEinExpr + path::EinExpr + size::Dict{Symbol,Int} + + function SizedEinExpr(path, size) + # inds(path) ⊆ keys(size) || throw(ArgumentError("")) + new(path, size) + end +end + +EinExpr(path::Vector{Symbol}, size::Dict{Symbol}) = SizedEinExpr(EinExpr(path), size) + +head(sexpr::SizedEinExpr) = head(sexpr.path) + +""" + args(sexpr::SizedEinExpr) + +# Note + +Unlike `args(::EinExpr)`, this function returns `SizedEinExpr` objects. +""" +args(sexpr::SizedEinExpr) = map(Base.Fix2(SizedEinExpr, sexpr.size), sexpr.path.args) # sexpr.path.args + +nargs(sexpr::SizedEinExpr) = nargs(sexpr.path) +inds(sexpr::SizedEinExpr) = inds(sexpr.path) + +function Base.getproperty(sexpr::SizedEinExpr, name::Symbol) + name === :head && return getfield(sexpr, :path).head + name === :args && return getfield(sexpr, :path).args + return getfield(sexpr, name) +end + +Base.:(==)(a::SizedEinExpr, b::SizedEinExpr) = a.path == b.path && a.size == b.size + +Base.ndims(sexpr::SizedEinExpr) = ndims(sexpr.path) + +Base.size(sexpr::SizedEinExpr) = size(sexpr.path, sexpr.size) +Base.size(sexpr::SizedEinExpr, i) = sexpr.size[i] +Base.length(sexpr::SizedEinExpr) = length(sexpr.path, sexpr.size) + +collapse!(sexpr::SizedEinExpr) = collapse!(sexpr.path) + +select(sexpr::SizedEinExpr, i) = map(Base.Fix2(SizedEinExpr, sexpr.size), select(sexpr.path, i)) + +neighbours(sexpr::SizedEinExpr, i) = map(Base.Fix2(SizedEinExpr, sexpr.size), neighbours(sexpr.path, i)) + +contractorder(sexpr::SizedEinExpr) = contractorder(sexpr.path) + +hyperinds(sexpr::SizedEinExpr) = hyperinds(sexpr.path) + +suminds(sexpr::SizedEinExpr) = suminds(sexpr.path) +parsuminds(sexpr::SizedEinExpr) = parsuminds(sexpr.path) + +Base.sum!(sexpr::SizedEinExpr, inds) = sum!(sexpr.path, inds) +Base.sum(sexpr::SizedEinExpr, inds) = sum(sexpr.path, inds) + +function Base.sum(sexpr::Vector{SizedEinExpr}; skip = Symbol[]) + path = sum(map(x -> x.path, sexpr); skip) + size = allequal(Iterators.map(x -> x.size, sexpr)) ? first(sexpr).size : merge(map(x -> x.size, sexpr)...) + # size = merge(map(x -> x.size, sexpr)...) + SizedEinExpr(path, size) +end + +# Iteration interface +Base.IteratorEltype(::Type{<:TreeIterator{SizedEinExpr}}) = Base.HasEltype() +Base.eltype(::Type{<:TreeIterator{SizedEinExpr}}) = SizedEinExpr + +# AbstractTrees interface and traits +AbstractTrees.children(sexpr::SizedEinExpr) = args(sexpr) +AbstractTrees.childtype(::Type{SizedEinExpr}) = SizedEinExpr +AbstractTrees.childrentype(::Type{SizedEinExpr}) = Vector{SizedEinExpr} +AbstractTrees.childstatetype(::Type{SizedEinExpr}) = Int +AbstractTrees.nodetype(::Type{SizedEinExpr}) = SizedEinExpr + +AbstractTrees.ParentLinks(::Type{SizedEinExpr}) = ImplicitParents() +AbstractTrees.SiblingLinks(::Type{SizedEinExpr}) = ImplicitSiblings() +AbstractTrees.ChildIndexing(::Type{SizedEinExpr}) = IndexedChildren() +AbstractTrees.NodeType(::Type{SizedEinExpr}) = HasNodeType() diff --git a/src/Slicing.jl b/src/Slicing.jl index 1330ec9..281ecf3 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -13,34 +13,33 @@ Project `index` to dimension `i` in a EinExpr. This is equivalent to tensor cutt See also: [`view`](@ref). """ -function Base.selectdim(path::EinExpr, index::Symbol, i) +Base.selectdim(path::EinExpr, ::Symbol, i) = path + +function Base.selectdim(path::EinExpr, index::Symbol, i::Integer) path = deepcopy(path) - for leave in Iterators.filter(∋(index) ∘ head, Leaves(path)) - leave.size[index] = length(i) + for expr in PreOrderDFS(path) + filter!(!=(index), expr.head) end return path end -function Base.selectdim(path::EinExpr, index::Symbol, _::Integer) - path = deepcopy(path) +function Base.selectdim(sexpr::SizedEinExpr, index::Symbol, i) + path = selectdim(sexpr.path, index, i) - index ∈ head(path) && (path = EinExpr(filter(!=(index), path.head), path.args)) - - for branch in Branches(path) - for arg in Iterators.filter(∋(index) ∘ head, branch.args) - replace!( - branch.args, - arg => EinExpr( - filter(!=(index), arg.head), - isempty(arg.args) ? filter(p -> p.first != index, arg.size) : arg.args, - ), - ) - end - end + size = copy(sexpr.size) + size[index] = length(i) - return path + return SizedEinExpr(path, size) +end + +function Base.selectdim(sexpr::SizedEinExpr, index::Symbol, i::Integer) + path = selectdim(sexpr.path, index, i) + + size = filter(!=(index) ∘ first, sexpr.size) + + return SizedEinExpr(path, size) end """ @@ -88,7 +87,7 @@ Reimplementation based on [`contengra`](https://github.com/jcmgray/cotengra)'s ` """ function findslices( scorer, - path::EinExpr; + path; size = nothing, overhead = nothing, slices = nothing, @@ -100,7 +99,7 @@ function findslices( candidates = Set(setdiff(mapreduce(head, ∪, PostOrderDFS(path)), skip)) solution = Set{Symbol}() - current = (; slices = 1, size = maximum(prod ∘ Base.size, PostOrderDFS(path)), overhead = 1.0) + current = (; slices = 1, size = maximum(length, PostOrderDFS(path)), overhead = 1.0) original_flops = mapreduce(flops, +, Branches(path; inverse = true)) sliced_path = path @@ -120,8 +119,8 @@ function findslices( push!(solution, winner) current = (; - slices = current.slices * (prod ∘ Base.size)(path, winner), - size = maximum(prod ∘ Base.size, PostOrderDFS(sliced_path)), + slices = current.slices * Base.size(path, winner), + size = maximum(length, PostOrderDFS(sliced_path)), overhead = cur_overhead, ) @@ -149,7 +148,7 @@ function (cb::FlopsScorer)(path, index) slice = selectdim(path, index, 1) flops_reduction = mapreduce(flops, +, PostOrderDFS(path)) - mapreduce(flops, +, PostOrderDFS(slice)) - write_reduction = mapreduce(prod ∘ size, +, PostOrderDFS(path)) - mapreduce(prod ∘ size, +, PostOrderDFS(slice)) + write_reduction = mapreduce(length, +, PostOrderDFS(path)) - mapreduce(length, +, PostOrderDFS(slice)) log(flops_reduction + write_reduction * cb.weight + 1) end @@ -169,7 +168,7 @@ function (cb::SizeScorer)(path, index) slice = selectdim(path, index, 1) flops_reduction = mapreduce(flops, +, PostOrderDFS(path)) - mapreduce(flops, +, PostOrderDFS(slice)) - write_reduction = mapreduce(prod ∘ size, +, PostOrderDFS(path)) - mapreduce(prod ∘ size, +, PostOrderDFS(slice)) + write_reduction = mapreduce(length, +, PostOrderDFS(path)) - mapreduce(length, +, PostOrderDFS(slice)) log(write_reduction + flops_reduction * cb.weight + 1) end diff --git a/test/Counters_test.jl b/test/Counters_test.jl index 6edc54e..8675ea5 100644 --- a/test/Counters_test.jl +++ b/test/Counters_test.jl @@ -1,75 +1,77 @@ @testset "Counters" begin using EinExprs: removedrank + sizedict = Dict(:i => 2, :j => 3, :k => 4, :l => 5) + @testset "identity" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) - expr = EinExpr([:i, :j], [tensor]) + tensor = EinExpr((:i, :j)) + expr = EinExpr((:i, :j), [tensor]) - @test flops(expr) == 0 - @test removedsize(expr) == 0 - @test removedrank(expr) == 0 + @test flops(expr, sizedict) == 0 + @test removedsize(expr, sizedict) == 0 + @test removedrank(expr, sizedict) == 0 end @testset "transpose" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + tensor = EinExpr((:i, :j)) expr = EinExpr([:j, :i], [tensor]) - @test flops(expr) == 0 - @test removedsize(expr) == 0 - @test removedrank(expr) == 0 + @test flops(expr, sizedict) == 0 + @test removedsize(expr, sizedict) == 0 + @test removedrank(expr, sizedict) == 0 end @testset "axis sum" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) - expr = EinExpr([:i], [tensor]) + tensor = EinExpr((:i, :j)) + expr = EinExpr((:i,), [tensor]) - @test flops(expr) == 6 - @test removedsize(expr) == 4 - @test removedrank(expr) == 1 + @test flops(expr, sizedict) == 6 + @test removedsize(expr, sizedict) == 4 + @test removedrank(expr, sizedict) == 1 end @testset "diagonal" begin - tensor = EinExpr([:i, :i], Dict(:i => 2)) - expr = EinExpr([:i], [tensor]) + tensor = EinExpr((:i, :i)) + expr = EinExpr((:i,), [tensor]) - @test flops(expr) == 0 - @test removedsize(expr) == 2 - @test removedrank(expr) == 1 + @test flops(expr, sizedict) == 0 + @test removedsize(expr, sizedict) == 2 + @test removedrank(expr, sizedict) == 1 end @testset "trace" begin - tensor = EinExpr([:i, :i], Dict(:i => 2)) + tensor = EinExpr((:i, :i)) expr = EinExpr(Symbol[], [tensor]) - @test flops(expr) == 2 - @test removedsize(expr) == 3 - @test removedrank(expr) == 2 + @test flops(expr, sizedict) == 2 + @test removedsize(expr, sizedict) == 3 + @test removedrank(expr, sizedict) == 2 end @testset "outer product" begin - tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:k, :l], Dict(:k => 4, :l => 5))] - expr = EinExpr([:i, :j, :k, :l], tensors) + tensors = [EinExpr((:i, :j)), EinExpr((:k, :l))] + expr = EinExpr((:i, :j, :k, :l), tensors) - @test flops(expr) == prod(2:5) - @test removedsize(expr) == -94 - @test removedrank(expr) == -2 + @test flops(expr, sizedict) == prod(2:5) + @test removedsize(expr, sizedict) == -94 + @test removedrank(expr, sizedict) == -2 end @testset "inner product" begin - tensors = [EinExpr([:i], Dict(:i => 2)), EinExpr([:i], Dict(:i => 2))] + tensors = [EinExpr((:i,)), EinExpr((:i,))] expr = EinExpr(Symbol[], tensors) - @test flops(expr) == 2 - @test removedsize(expr) == 3 - @test removedrank(expr) == 1 + @test flops(expr, sizedict) == 2 + @test removedsize(expr, sizedict) == 3 + @test removedrank(expr, sizedict) == 1 end @testset "matrix multiplication" begin - tensors = [EinExpr([:i, :k], Dict(:i => 2, :k => 3)), EinExpr([:k, :j], Dict(:k => 3, :j => 4))] - expr = EinExpr([:i, :j], tensors) + tensors = [EinExpr((:i, :j)), EinExpr((:j, :k))] + expr = EinExpr((:i, :k), tensors) - @test flops(expr) == 2 * 3 * 4 - @test removedsize(expr) == 10 - @test removedrank(expr) == 0 + @test flops(expr, sizedict) == 2 * 3 * 4 + @test removedsize(expr, sizedict) == 10 + @test removedrank(expr, sizedict) == 0 end end diff --git a/test/EinExpr_test.jl b/test/EinExpr_test.jl index 76c1d79..114e133 100644 --- a/test/EinExpr_test.jl +++ b/test/EinExpr_test.jl @@ -2,7 +2,7 @@ using LinearAlgebra @testset "identity" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + tensor = EinExpr([:i, :j]) expr = EinExpr([:i, :j], [tensor]) @test expr.head == head(tensor) @@ -11,10 +11,6 @@ @test head(expr) == head(tensor) @test ndims(expr) == 2 - @test size(expr, :i) == 2 - @test size(expr, :j) == 3 - @test size(expr) == (2, 3) - @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @test isempty(parsuminds(expr)) @@ -27,7 +23,7 @@ end @testset "transpose" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + tensor = EinExpr([:i, :j]) expr = EinExpr([:j, :i], [tensor]) @test expr.head == reverse(inds(tensor)) @@ -36,10 +32,6 @@ @test head(expr) == reverse(inds(tensor)) @test ndims(expr) == 2 - @test size(expr, :i) == 2 - @test size(expr, :j) == 3 - @test size(expr) == (3, 2) - @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @test isempty(parsuminds(expr)) @@ -52,7 +44,7 @@ end @testset "axis sum" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + tensor = EinExpr([:i, :j]) expr = EinExpr((:i,), [tensor]) @test all(splat(==), zip(expr.head, [:i])) @@ -61,10 +53,6 @@ @test all(splat(==), zip(head(expr), (:i,))) @test all(splat(==), zip(inds(expr), [:i, :j])) - @test size(expr, :i) == 2 - @test size(expr, :j) == 3 - @test size(expr) == (2,) - @test isempty(hyperinds(expr)) @test suminds(expr) == [:j] @test isempty(parsuminds(expr)) @@ -77,7 +65,7 @@ end @testset "diagonal" begin - tensor = EinExpr([:i, :i], Dict(:i => 2)) + tensor = EinExpr([:i, :i]) expr = EinExpr((:i,), [tensor]) @test all(splat(==), zip(expr.head, (:i,))) @@ -86,9 +74,6 @@ @test all(splat(==), zip(head(expr), (:i,))) @test all(splat(==), zip(inds(expr), head(expr))) - @test size(expr, :i) == 2 - @test size(expr) == (2,) - @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @test isempty(parsuminds(expr)) @@ -99,7 +84,7 @@ end @testset "trace" begin - tensor = EinExpr([:i, :i], Dict(:i => 2)) + tensor = EinExpr([:i, :i]) expr = EinExpr(Symbol[], [tensor]) @test isempty(expr.head) @@ -108,9 +93,6 @@ @test isempty(head(expr)) @test all(splat(==), zip(inds(expr), (:i,))) - @test size(expr, :i) == 2 - @test size(expr) == () - @test isempty(hyperinds(expr)) @test suminds(expr) == [:i] @test isempty(parsuminds(expr)) @@ -121,7 +103,7 @@ end @testset "outer product" begin - tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:k, :l], Dict(:k => 4, :l => 5))] + tensors = [EinExpr([:i, :j]), EinExpr([:k, :l])] expr = EinExpr([:i, :j, :k, :l], tensors) @test all(splat(==), zip(expr.head, (:i, :j, :k, :l))) @@ -131,11 +113,6 @@ @test all(splat(==), zip(inds(expr), head(expr))) @test ndims(expr) == 4 - for (i, d) in zip([:i, :j, :k, :l], [2, 3, 4, 5]) - @test size(expr, i) == d - end - @test size(expr) == (2, 3, 4, 5) - @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @test isempty(parsuminds(expr)) @@ -151,7 +128,7 @@ @testset "inner product" begin @testset "Vector" begin - tensors = [EinExpr([:i], Dict(:i => 2)), EinExpr([:i], Dict(:i => 2))] + tensors = [EinExpr([:i]), EinExpr([:i])] expr = EinExpr(Symbol[], tensors) @test isempty(expr.head) @@ -161,9 +138,6 @@ @test all(splat(==), zip(inds(expr), (:i,))) @test ndims(expr) == 0 - @test size(expr, :i) == 2 - @test size(expr) == () - @test isempty(hyperinds(expr)) @test suminds(expr) == [:i] @test parsuminds(expr) == [[:i]] @@ -173,7 +147,7 @@ @test isempty(neighbours(expr, :i)) end @testset "Matrix" begin - tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:i, :j], Dict(:i => 2, :j => 3))] + tensors = [EinExpr([:i, :j]), EinExpr([:i, :j])] expr = EinExpr(Symbol[], tensors) @test isempty(expr.head) @@ -183,10 +157,6 @@ @test all(splat(==), zip(inds(expr), [:i, :j])) @test ndims(expr) == 0 - @test size(expr, :i) == 2 - @test size(expr, :j) == 3 - @test size(expr) == () - @test isempty(hyperinds(expr)) @test issetequal(suminds(expr), [:i, :j]) @test Set(Set.(parsuminds(expr))) == Set([Set([:i, :j])]) @@ -199,7 +169,7 @@ end @testset "matrix multiplication" begin - tensors = [EinExpr([:i, :k], Dict(:i => 2, :k => 3)), EinExpr([:k, :j], Dict(:k => 3, :j => 4))] + tensors = [EinExpr([:i, :k]), EinExpr([:k, :j])] expr = EinExpr([:i, :j], tensors) @test all(splat(==), zip(expr.head, [:i, :j])) @@ -209,11 +179,6 @@ @test all(splat(==), zip(inds(expr), (:i, :k, :j))) @test ndims(expr) == 2 - @test size(expr, :i) == 2 - @test size(expr, :j) == 4 - @test size(expr, :k) == 3 - @test size(expr) == (2, 4) - @test isempty(hyperinds(expr)) @test suminds(expr) == [:k] @test parsuminds(expr) == [[:k]] @@ -228,21 +193,13 @@ @testset "hyperindex contraction" begin @testset "hyperindex is not summed" begin - tensors = [ - EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])), - EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])), - EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])), - ] - + tensors = [EinExpr([:i, :β, :j]), EinExpr([:k, :β]), EinExpr([:β, :l, :m])] expr = sum(tensors, skip = [:β]) @test issetequal(head(expr), (:i, :j, :k, :l, :m, :β)) @test issetequal(inds(expr), (:i, :j, :k, :l, :m, :β)) @test ndims(expr) == 6 - @test all(i -> size(expr, i) == 2, inds(expr)) - @test size(expr) == tuple(fill(2, 6)...) - @test issetequal(hyperinds(expr), [:β]) @test isempty(suminds(expr)) @test_broken isempty(parsuminds(expr)) @@ -257,12 +214,7 @@ end @testset "hyperindex is summed" begin - tensors = [ - EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])), - EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])), - EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])), - ] - + tensors = [EinExpr([:i, :β, :j]), EinExpr([:k, :β]), EinExpr([:β, :l, :m])] expr = sum(tensors) @test all(splat(==), zip(expr.head, (:i, :j, :k, :l, :m))) @@ -272,9 +224,6 @@ @test issetequal(inds(expr), (:i, :j, :k, :l, :m, :β)) @test ndims(expr) == 5 - @test all(i -> size(expr, i) == 2, inds(expr)) - @test size(expr) == tuple(fill(2, 5)...) - @test issetequal(hyperinds(expr), [:β]) @test issetequal(suminds(expr), [:β]) @test issetequal(parsuminds(expr), [[:β]]) @@ -291,13 +240,13 @@ @testset "manual path" begin tensors = [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), + EinExpr([:j, :b, :i, :h]), + EinExpr([:a, :c, :e, :f]), + EinExpr([:j]), + EinExpr([:e, :a, :g]), + EinExpr([:f, :b]), + EinExpr([:i, :h, :d]), + EinExpr([:d, :g, :c]), ] path = EinExpr(Symbol[], tensors) diff --git a/test/Exhaustive_test.jl b/test/Exhaustive_test.jl index a27316c..d4954ed 100644 --- a/test/Exhaustive_test.jl +++ b/test/Exhaustive_test.jl @@ -1,31 +1,35 @@ @testset "Exhaustive" begin + sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j]) tensors = [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), + EinExpr([:j, :b, :i, :h]), + EinExpr([:a, :c, :e, :f]), + EinExpr([:j]), + EinExpr([:e, :a, :g]), + EinExpr([:f, :b]), + EinExpr([:i, :h, :d]), + EinExpr([:d, :g, :c]), ] + expr = EinExpr(Symbol[], tensors) + sexpr = SizedEinExpr(expr, sizedict) - path = einexpr(Exhaustive, EinExpr(Symbol[], tensors)) + path = einexpr(Exhaustive, sexpr) - @test path isa EinExpr + @test path isa SizedEinExpr - @test mapreduce(flops, +, Branches(path)) == 90 + @test mapreduce(flops, +, Branches(path)) == 92 - @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]])) + @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:i, :h], [:d, :b]])) @testset "hyperedges" begin - a = EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])) - b = EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])) - c = EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])) + sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, :β]) + a = EinExpr([:i, :β, :j]) + b = EinExpr([:k, :β]) + c = EinExpr([:β, :l, :m]) - path = einexpr(EinExprs.Exhaustive(), sum([a, b, c], skip = [:β])) + path = einexpr(EinExprs.Exhaustive(), SizedEinExpr(sum([a, b, c], skip = [:β]), sizedict)) @test all(∋(:β) ∘ head, branches(path)) - path = einexpr(EinExprs.Exhaustive(), sum([a, b, c], skip = Symbol[])) + path = einexpr(EinExprs.Exhaustive(), SizedEinExpr(sum([a, b, c], skip = Symbol[]), sizedict)) @test all(∋(:β) ∘ head, branches(path)[1:end-1]) @test all(!∋(:β) ∘ head, branches(path)[end:end]) end diff --git a/test/Greedy_test.jl b/test/Greedy_test.jl index f010dfb..7416719 100644 --- a/test/Greedy_test.jl +++ b/test/Greedy_test.jl @@ -1,42 +1,45 @@ @testset "Greedy" begin + sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j]) tensors = [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), + EinExpr([:j, :b, :i, :h]), + EinExpr([:a, :c, :e, :f]), + EinExpr([:j]), + EinExpr([:e, :a, :g]), + EinExpr([:f, :b]), + EinExpr([:i, :h, :d]), + EinExpr([:d, :g, :c]), ] + expr = sum(tensors) - path = einexpr(Greedy(), EinExpr(Symbol[], tensors)) + path = einexpr(Greedy(), SizedEinExpr(expr, sizedict)) - @test path isa EinExpr + @test path isa SizedEinExpr @test mapreduce(flops, +, Branches(path)) == 100 @test all(splat(issetequal), zip(contractorder(path), [[:i, :h], [:j], [:a, :e], [:g, :c], [:f], [:b, :d]])) @testset "example: let unchanged" begin - tensors = [ - EinExpr([:i, :j, :k], Dict(:i => 2, :j => 2, :k => 2)), - EinExpr([:k, :l, :m], Dict(:k => 2, :l => 2, :m => 2)), - ] + sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m]) + tensors = [EinExpr([:i, :j, :k]), EinExpr([:k, :l, :m])] + expr = sum(tensors, skip = [:i, :j, :l, :m]) + sexpr = SizedEinExpr(expr, sizedict) - path = einexpr(Greedy(), EinExpr(Symbol[:i, :j, :l, :m], tensors)) + path = einexpr(Greedy(), sexpr) @test suminds(path) == [:k] end @testset "hyperedges" begin - a = EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])) - b = EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])) - c = EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])) + sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, :β]) + a = EinExpr([:i, :β, :j]) + b = EinExpr([:k, :β]) + c = EinExpr([:β, :l, :m]) - path = einexpr(EinExprs.Greedy(), sum([a, b, c], skip = [:β])) + path = einexpr(EinExprs.Greedy(), SizedEinExpr(sum([a, b, c], skip = [:β]), sizedict)) @test all(∋(:β) ∘ head, branches(path)) - path = einexpr(EinExprs.Greedy(), sum([a, b, c], skip = Symbol[])) + path = einexpr(EinExprs.Greedy(), SizedEinExpr(sum([a, b, c], skip = Symbol[]), sizedict)) @test all(∋(:β) ∘ head, branches(path)[1:end-1]) @test all(!∋(:β) ∘ head, branches(path)[end:end]) end diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl index f4cdc84..0794b45 100644 --- a/test/KaHyPar_test.jl +++ b/test/KaHyPar_test.jl @@ -9,10 +9,11 @@ EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), ] + sexpr = sum(tensors) - path = einexpr(HyPar(imbalance=0.42), EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance = 0.42), sexpr) - @test path isa EinExpr + @test path isa SizedEinExpr @test mapreduce(flops, +, Branches(path)) == 108 end @@ -40,10 +41,11 @@ EinExpr([:A, :W], Dict(:A => 6, :W => 6)), EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)), ] + sexpr = sum(tensors) - path = einexpr(HyPar(imbalance=0.45), EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance = 0.45), sexpr) - @test path isa EinExpr + @test path isa SizedEinExpr @test mapreduce(flops, +, Branches(path)) == 19099592 end diff --git a/test/Naive_test.jl b/test/Naive_test.jl index 3e01037..2b8f15c 100644 --- a/test/Naive_test.jl +++ b/test/Naive_test.jl @@ -1,21 +1,22 @@ @testset "Naive" begin + sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j]) tensors = [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), + EinExpr([:j, :b, :i, :h]), + EinExpr([:a, :c, :e, :f]), + EinExpr([:j]), + EinExpr([:e, :a, :g]), + EinExpr([:f, :b]), + EinExpr([:i, :h, :d]), + EinExpr([:d, :g, :c]), ] - path = einexpr(EinExprs.Naive(), EinExpr(Symbol[], tensors)) + path = einexpr(EinExprs.Naive(), EinExpr(Symbol[], tensors), sizedict) @test path isa EinExpr @test foldl((a, b) -> sum([a, b]), tensors) == path # TODO traverse through the tree and check everything is ok - @test mapreduce(flops, +, Branches(path)) == 872 + @test mapreduce(flops(sizedict), +, Branches(path)) == 872 # FIXME non-determinist behaviour on order @test all( @@ -24,14 +25,15 @@ ) @testset "hyperedges" begin - a = EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])) - b = EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])) - c = EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])) + sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, :β]) + a = EinExpr([:i, :β, :j]) + b = EinExpr([:k, :β]) + c = EinExpr([:β, :l, :m]) - path = einexpr(EinExprs.Naive(), sum([a, b, c], skip = [:β])) + path = einexpr(EinExprs.Naive(), sum([a, b, c], skip = [:β]), sizedict) @test all(∋(:β) ∘ head, branches(path)) - path = einexpr(EinExprs.Naive(), sum([a, b, c], skip = Symbol[])) + path = einexpr(EinExprs.Naive(), sum([a, b, c], skip = Symbol[]), sizedict) @test all(∋(:β) ∘ head, branches(path)[1:end-1]) @test all(!∋(:β) ∘ head, branches(path)[end:end]) end diff --git a/test/SizedEinExpr_test.jl b/test/SizedEinExpr_test.jl new file mode 100644 index 0000000..3dccf4d --- /dev/null +++ b/test/SizedEinExpr_test.jl @@ -0,0 +1,23 @@ +@testset "SizedEinExpr" begin + using LinearAlgebra + + tensor = EinExpr([:i, :j]) + expr = EinExpr([:i, :j], [tensor]) + sexpr = SizedEinExpr(expr, Dict(:i => 2, :j => 3)) + + @test head(sexpr) === head(expr) === sexpr.head + @test args(expr) === sexpr.args + @test args(sexpr) == map(Base.Fix2(SizedEinExpr, Dict(:i => 2, :j => 3)), args(expr)) + @test EinExprs.nargs(sexpr) == EinExprs.nargs(expr) + + @test inds(sexpr) == inds(expr) + @test ndims(sexpr) == ndims(expr) + @test length(sexpr) == 6 + + @test size(sexpr, :i) == 2 + @test size(sexpr, :j) == 3 + @test size(sexpr) == (2, 3) + + @test select(sexpr, :i) == SizedEinExpr[sexpr, SizedEinExpr(tensor, Dict(:i => 2, :j => 3))] + @test select(sexpr, :j) == SizedEinExpr[sexpr, SizedEinExpr(tensor, Dict(:i => 2, :j => 3))] +end diff --git a/test/Slicing_test.jl b/test/Slicing_test.jl index 720038d..5980863 100644 --- a/test/Slicing_test.jl +++ b/test/Slicing_test.jl @@ -42,49 +42,32 @@ [ EinExpr( [:m, :f, :g], - [ - EinExpr( - [:m, :f, :q], - Dict(i => sizes[i] for i in [:m, :f, :q]), - ), - EinExpr( - [:g, :q], - Dict(i => sizes[i] for i in [:g, :q]), - ), - ], - ), - EinExpr( - [:o, :i, :m, :c], - Dict(i => sizes[i] for i in [:o, :i, :m, :c]), + [EinExpr((:m, :f, :q),), EinExpr((:g, :q),)], ), + EinExpr((:o, :i, :m, :c),), ], ), - EinExpr([:f, :l, :i], Dict(i => sizes[i] for i in [:f, :l, :i])), + EinExpr((:f, :l, :i)), ], ), - EinExpr([:g, :n, :l, :a], Dict(i => sizes[i] for i in [:g, :n, :l, :a])), - ], - ), - EinExpr( - [:e, :d, :o], - [ - EinExpr([:b, :e], Dict(i => sizes[i] for i in [:b, :e])), - EinExpr([:d, :b, :o], Dict(i => sizes[i] for i in [:d, :b, :o])), + EinExpr((:g, :n, :l, :a)), ], ), + EinExpr([:e, :d, :o], [EinExpr((:b, :e)), EinExpr((:d, :b, :o))]), ], ), - EinExpr([:c, :e, :h], Dict(i => sizes[i] for i in [:c, :e, :h])), + EinExpr((:c, :e, :h)), ], ), - EinExpr([:k, :d, :h, :a, :n, :j], Dict(i => sizes[i] for i in [:k, :d, :h, :a, :n, :j])), + EinExpr((:k, :d, :h, :a, :n, :j)), ], ), - EinExpr([:p, :k], Dict(i => sizes[i] for i in [:p, :k])), + EinExpr((:p, :k)), ], ) + sexpr = SizedEinExpr(expr, sizes) - cuttings = findslices(FlopsScorer(), expr, slices = 1000) + cuttings = findslices(FlopsScorer(), sexpr, slices = 1000) - @test prod(i -> size(expr, i), cuttings) >= 1000 + @test prod(i -> sizes[i], cuttings) >= 1000 end diff --git a/test/ext/Makie_test.jl b/test/ext/Makie_test.jl index 74cccdb..98543a7 100644 --- a/test/ext/Makie_test.jl +++ b/test/ext/Makie_test.jl @@ -34,8 +34,8 @@ EinExpr([:g, :q], filter(p -> p.first ∈ [:g, :q], sizes)), EinExpr([:d, :b, :o], filter(p -> p.first ∈ [:d, :b, :o], sizes)), ] - - path = einexpr(Exhaustive(), EinExpr([:p, :j], tensors)) + expr = sum(tensors, skip = [:p, :j]) + path = einexpr(Exhaustive(), expr) @testset "plot!" begin f = Figure() diff --git a/test/runtests.jl b/test/runtests.jl index ead32fd..9248f5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using EinExprs @testset "Unit tests" verbose = true begin include("EinExpr_test.jl") + include("SizedEinExpr_test.jl") include("Counters_test.jl") @testset "Optimizers" begin include("Naive_test.jl")