diff --git a/src/Slicing.jl b/src/Slicing.jl index 82e6b07..281ecf3 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -13,34 +13,33 @@ Project `index` to dimension `i` in a EinExpr. This is equivalent to tensor cutt See also: [`view`](@ref). """ -function Base.selectdim(path::EinExpr, index::Symbol, i) +Base.selectdim(path::EinExpr, ::Symbol, i) = path + +function Base.selectdim(path::EinExpr, index::Symbol, i::Integer) path = deepcopy(path) - for leave in Iterators.filter(∋(index) ∘ head, Leaves(path)) - leave.size[index] = length(i) + for expr in PreOrderDFS(path) + filter!(!=(index), expr.head) end return path end -function Base.selectdim(path::EinExpr, index::Symbol, _::Integer) - path = deepcopy(path) +function Base.selectdim(sexpr::SizedEinExpr, index::Symbol, i) + path = selectdim(sexpr.path, index, i) - index ∈ head(path) && (path = EinExpr(filter(!=(index), path.head), path.args)) - - for branch in Branches(path) - for arg in Iterators.filter(∋(index) ∘ head, branch.args) - replace!( - branch.args, - arg => EinExpr( - filter(!=(index), arg.head), - isempty(arg.args) ? filter(p -> p.first != index, arg.size) : arg.args, - ), - ) - end - end + size = copy(sexpr.size) + size[index] = length(i) - return path + return SizedEinExpr(path, size) +end + +function Base.selectdim(sexpr::SizedEinExpr, index::Symbol, i::Integer) + path = selectdim(sexpr.path, index, i) + + size = filter(!=(index) ∘ first, sexpr.size) + + return SizedEinExpr(path, size) end """ @@ -88,8 +87,7 @@ Reimplementation based on [`contengra`](https://github.com/jcmgray/cotengra)'s ` """ function findslices( scorer, - path::EinExpr, - sizedict; + path; size = nothing, overhead = nothing, slices = nothing, @@ -101,8 +99,8 @@ function findslices( candidates = Set(setdiff(mapreduce(head, ∪, PostOrderDFS(path)), skip)) solution = Set{Symbol}() - current = (; slices = 1, size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(path)), overhead = 1.0) - original_flops = mapreduce(flops(sizedict), +, Branches(path; inverse = true)) + current = (; slices = 1, size = maximum(length, PostOrderDFS(path)), overhead = 1.0) + original_flops = mapreduce(flops, +, Branches(path; inverse = true)) sliced_path = path while !isempty(candidates) @@ -114,15 +112,15 @@ function findslices( sliced_path = selectdim(sliced_path, winner, 1) cur_overhead = - prod(i -> sizedict[i], [solution..., winner]) * - mapreduce(flops(sizedict), +, Branches(sliced_path; inverse = true)) / original_flops + prod(i -> Base.size(path, i), [solution..., winner]) * + mapreduce(flops, +, Branches(sliced_path; inverse = true)) / original_flops !isnothing(overhead) && cur_overhead > overhead && break push!(solution, winner) current = (; - slices = current.slices * Base.Fix2(length, sizedict)(path, winner), - size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(sliced_path)), + slices = current.slices * Base.size(path, winner), + size = maximum(length, PostOrderDFS(sliced_path)), overhead = cur_overhead, ) @@ -150,7 +148,7 @@ function (cb::FlopsScorer)(path, index) slice = selectdim(path, index, 1) flops_reduction = mapreduce(flops, +, PostOrderDFS(path)) - mapreduce(flops, +, PostOrderDFS(slice)) - write_reduction = mapreduce(prod ∘ size, +, PostOrderDFS(path)) - mapreduce(prod ∘ size, +, PostOrderDFS(slice)) + write_reduction = mapreduce(length, +, PostOrderDFS(path)) - mapreduce(length, +, PostOrderDFS(slice)) log(flops_reduction + write_reduction * cb.weight + 1) end @@ -170,7 +168,7 @@ function (cb::SizeScorer)(path, index) slice = selectdim(path, index, 1) flops_reduction = mapreduce(flops, +, PostOrderDFS(path)) - mapreduce(flops, +, PostOrderDFS(slice)) - write_reduction = mapreduce(prod ∘ size, +, PostOrderDFS(path)) - mapreduce(prod ∘ size, +, PostOrderDFS(slice)) + write_reduction = mapreduce(length, +, PostOrderDFS(path)) - mapreduce(length, +, PostOrderDFS(slice)) log(write_reduction + flops_reduction * cb.weight + 1) end diff --git a/test/Slicing_test.jl b/test/Slicing_test.jl index 6e9f6ef..5980863 100644 --- a/test/Slicing_test.jl +++ b/test/Slicing_test.jl @@ -65,8 +65,9 @@ EinExpr((:p, :k)), ], ) + sexpr = SizedEinExpr(expr, sizes) - cuttings = findslices(FlopsScorer(), expr, slices = 1000) + cuttings = findslices(FlopsScorer(), sexpr, slices = 1000) - @test prod(i -> sizedict[i], cuttings) >= 1000 + @test prod(i -> sizes[i], cuttings) >= 1000 end