Skip to content

Commit

Permalink
Extend KrylovKit.eigsolve by allowing the initial guess x₀ to be …
Browse files Browse the repository at this point in the history
…a `Tensor` (#171)

* Extend eigsolve by allowing initial guess to be a Tensor

* Minor fixes

* Apply suggestions from code review

Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>

* Deduplicate code

* Reorder code to reuse variable

---------

Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
  • Loading branch information
jofrevalles and mofeing authored Jul 16, 2024
1 parent 0500be5 commit 0b68cf6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 19 deletions.
60 changes: 43 additions & 17 deletions ext/TenetKrylovKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,28 @@ function eigsolve_prehook_tensor_reshape(A::Tensor, left_inds, right_inds)
return Amat, left_sizes, right_sizes
end

function KrylovKit.eigselector(A::Tensor, T::Type; left_inds=Symbol[], right_inds=Symbol[], kwargs...)
Amat, _, _ = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)
return KrylovKit.eigselector(Amat, T; kwargs...)
function eigsolve_prehook_tensor_reshape(A::Tensor, x₀::Tensor, left_inds, right_inds)
left_inds, right_inds = Tenet.factorinds(A, left_inds, right_inds)

Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)
prod_left_sizes = prod(left_sizes)

inds(x₀) != left_inds && throw(
ArgumentError(
"The initial guess must have the same left indices as the tensor, but got $(inds(x₀)) and $left_inds."
),
)
prod(size.((x₀,), left_inds)) != prod_left_sizes && throw(
ArgumentError(
"The initial guess must have the same size as the left indices, but got sizes $prod_x₀_sizes and $prod_left_sizes.",
),
)

# Permute and reshape the tensor
x₀ = permutedims(x₀, left_inds)
x₀vec = reshape(parent(x₀), prod_left_sizes)

return Amat, left_sizes, right_sizes, x₀vec
end

function KrylovKit.eigsolve(
Expand All @@ -50,20 +69,6 @@ function KrylovKit.eigsolve(
return vals, Avecs, info
end

function KrylovKit.eigsolve(
f::Tensor, x₀, howmany::Int=1, which::KrylovKit.Selector=:LM; left_inds=Symbol[], right_inds=Symbol[], kwargs...
)
Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)

# Compute eigenvalues and eigenvectors
vals, vecs, info = KrylovKit.eigsolve(Amat, x₀, howmany, which; kwargs...)

# Tensorify the eigenvectors
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs]

return vals, Avecs, info
end

"""
KrylovKit.eigsolve(tensor::Tensor; left_inds, right_inds, kwargs...)
Expand Down Expand Up @@ -95,4 +100,25 @@ function KrylovKit.eigsolve(
return vals, Avecs, info
end

function KrylovKit.eigsolve(
A::Tensor,
x₀::Tensor,
howmany::Int,
which::KrylovKit.Selector,
alg::Algorithm;
left_inds=inds(x₀),
right_inds=Symbol[],
kwargs...,
) where {Algorithm<:KrylovKit.Lanczos} # KrylovKit.KrylovAlgorithm}
Amat, left_sizes, right_sizes, x₀vec = eigsolve_prehook_tensor_reshape(A, x₀, left_inds, right_inds)

# Compute eigenvalues and eigenvectors
vals, vecs, info = KrylovKit.eigsolve(Amat, x₀vec, howmany, which, alg; kwargs...)

# Tensorify the eigenvectors
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs]

return vals, Avecs, info
end

end
20 changes: 18 additions & 2 deletions test/integration/KrylovKit_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,24 @@
@test parent(reconstructed_tensor) parent(transpose(reconstructed_tensor_perm))

@testset "Lanczos" begin
vals_lanczos, vecs_lanczos = eigsolve(
tensor, rand(ComplexF64, 4), 1, :SR, Lanczos(; krylovdim=2, tol=1e-16); left_inds=[:i], right_inds=[:j]
@test_throws ArgumentError eigsolve(
tensor,
Tensor(rand(ComplexF64, 4), (:j,)),
1,
:SR,
Lanczos(; krylovdim=2, tol=1e-16);
left_inds=[:i],
right_inds=[:j],
)

vals_lanczos, vecs_lanczos, info = eigsolve(
tensor,
Tensor(rand(ComplexF64, 4), (:i,)),
1,
:SR,
Lanczos(; krylovdim=2, tol=1e-16);
left_inds=[:i],
right_inds=[:j],
)

@test length(vals_lanczos) == 1
Expand Down

0 comments on commit 0b68cf6

Please sign in to comment.