Skip to content

Commit

Permalink
Refactor selectdim and slicing methods for SizedEinExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Dec 28, 2023
1 parent f1e7d31 commit 49448ae
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 31 deletions.
56 changes: 27 additions & 29 deletions src/Slicing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,33 @@ Project `index` to dimension `i` in a EinExpr. This is equivalent to tensor cutt
See also: [`view`](@ref).
"""
function Base.selectdim(path::EinExpr, index::Symbol, i)
Base.selectdim(path::EinExpr, ::Symbol, i) = path

function Base.selectdim(path::EinExpr, index::Symbol, i::Integer)
path = deepcopy(path)

for leave in Iterators.filter((index) head, Leaves(path))
leave.size[index] = length(i)
for expr in PreOrderDFS(path)
filter!(!=(index), expr.head)
end

return path
end

function Base.selectdim(path::EinExpr, index::Symbol, _::Integer)
path = deepcopy(path)
function Base.selectdim(sexpr::SizedEinExpr, index::Symbol, i)
path = selectdim(sexpr.path, index, i)

index head(path) && (path = EinExpr(filter(!=(index), path.head), path.args))

for branch in Branches(path)
for arg in Iterators.filter((index) head, branch.args)
replace!(
branch.args,
arg => EinExpr(
filter(!=(index), arg.head),
isempty(arg.args) ? filter(p -> p.first != index, arg.size) : arg.args,
),
)
end
end
size = copy(sexpr.size)
size[index] = length(i)

return path
return SizedEinExpr(path, size)
end

function Base.selectdim(sexpr::SizedEinExpr, index::Symbol, i::Integer)
path = selectdim(sexpr.path, index, i)

size = filter(!=(index) first, sexpr.size)

return SizedEinExpr(path, size)
end

"""
Expand Down Expand Up @@ -88,8 +87,7 @@ Reimplementation based on [`contengra`](https://github.com/jcmgray/cotengra)'s `
"""
function findslices(
scorer,
path::EinExpr,
sizedict;
path;
size = nothing,
overhead = nothing,
slices = nothing,
Expand All @@ -101,8 +99,8 @@ function findslices(

candidates = Set(setdiff(mapreduce(head, , PostOrderDFS(path)), skip))
solution = Set{Symbol}()
current = (; slices = 1, size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(path)), overhead = 1.0)
original_flops = mapreduce(flops(sizedict), +, Branches(path; inverse = true))
current = (; slices = 1, size = maximum(length, PostOrderDFS(path)), overhead = 1.0)
original_flops = mapreduce(flops, +, Branches(path; inverse = true))

sliced_path = path
while !isempty(candidates)
Expand All @@ -114,15 +112,15 @@ function findslices(

sliced_path = selectdim(sliced_path, winner, 1)
cur_overhead =
prod(i -> sizedict[i], [solution..., winner]) *
mapreduce(flops(sizedict), +, Branches(sliced_path; inverse = true)) / original_flops
prod(i -> Base.size(path, i), [solution..., winner]) *
mapreduce(flops, +, Branches(sliced_path; inverse = true)) / original_flops

!isnothing(overhead) && cur_overhead > overhead && break
push!(solution, winner)

current = (;
slices = current.slices * Base.Fix2(length, sizedict)(path, winner),
size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(sliced_path)),
slices = current.slices * Base.size(path, winner),
size = maximum(length, PostOrderDFS(sliced_path)),
overhead = cur_overhead,
)

Expand Down Expand Up @@ -150,7 +148,7 @@ function (cb::FlopsScorer)(path, index)
slice = selectdim(path, index, 1)

flops_reduction = mapreduce(flops, +, PostOrderDFS(path)) - mapreduce(flops, +, PostOrderDFS(slice))
write_reduction = mapreduce(prod size, +, PostOrderDFS(path)) - mapreduce(prod size, +, PostOrderDFS(slice))
write_reduction = mapreduce(length, +, PostOrderDFS(path)) - mapreduce(length, +, PostOrderDFS(slice))

log(flops_reduction + write_reduction * cb.weight + 1)
end
Expand All @@ -170,7 +168,7 @@ function (cb::SizeScorer)(path, index)
slice = selectdim(path, index, 1)

flops_reduction = mapreduce(flops, +, PostOrderDFS(path)) - mapreduce(flops, +, PostOrderDFS(slice))
write_reduction = mapreduce(prod size, +, PostOrderDFS(path)) - mapreduce(prod size, +, PostOrderDFS(slice))
write_reduction = mapreduce(length, +, PostOrderDFS(path)) - mapreduce(length, +, PostOrderDFS(slice))

log(write_reduction + flops_reduction * cb.weight + 1)
end
5 changes: 3 additions & 2 deletions test/Slicing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@
EinExpr((:p, :k)),
],
)
sexpr = SizedEinExpr(expr, sizes)

cuttings = findslices(FlopsScorer(), expr, slices = 1000)
cuttings = findslices(FlopsScorer(), sexpr, slices = 1000)

@test prod(i -> sizedict[i], cuttings) >= 1000
@test prod(i -> sizes[i], cuttings) >= 1000
end

0 comments on commit 49448ae

Please sign in to comment.