Using Zygote on cost functions with ITensors.truncate

Dear ITensor Team,

Thank you for developing such user-friendly and efficient code.

I am currently working on optimizing matrix product states (MPS) / tensor trains (TT) using Zygote for automatic differentiation. Specifically, I am trying to optimize each element of the MPS/TT core tensors. However, I have encountered an issue when attempting to compress the bond dimensions inside the loss function.

In the provided code, I am using ITensors.truncate inside the ttfit function, which serves as the loss function. Unfortunately, this results in the error: “no method matching iterate(::Nothing)”. I suspect this error arises because ITensors.truncate might return nothing under certain conditions.

My main question is: What would be a suitable approach to perform bond dimension compression within the loss function, given the constraints of automatic differentiation with Zygote?
I would be grateful for any insights, suggestions, or best practices you could share regarding this problem.

I have attached minimal code to reproduce my issue below.

using ITensors
using Random
using OptimKit
using Zygote
ITensors.disable_warn_order()

R = 8        # number of tensors (sites) in the MPS
D = 2        # initial bond dimension
fit_itr = 10 # maximum number of optimizer iterations
sitedim = 2  # local (physical) dimension of each site
sites = siteinds(sitedim, R)

# Make a random MPS stored as a vector of rank-3 arrays
function make_random_mps(sitedims::Vector{Int64}, linkdims::Vector{Int64}, nsites::Int64)::Vector{Array{Float64, 3}}
    res = Array{Float64, 3}[]  # concrete element type, matching the declared return type
    for i in 1:nsites
        if i == 1
            push!(res, randn(1, sitedims[i], linkdims[i]))
        elseif i == nsites
            push!(res, randn(linkdims[i-1], sitedims[i], 1))
        else
            push!(res, randn(linkdims[i-1], sitedims[i], linkdims[i]))
        end
    end
    return res
end

# Get the bond dimensions
function linkdims(tt::Vector{Array{Float64, 3}})
    return [size(T, 1) for T in tt[2:end]]
end

# Convert MPS to ITensors.MPS
function ITensors.MPS(tt::Vector{Array{Float64, 3}}, sites=nothing)::MPS
    N = length(tt)
    localdims = [size(t, 2) for t in tt]
    if sites === nothing
        sites = [Index(localdims[n], "n=$n") for n in 1:N]
    else
        all(localdims .== [ITensors.dim(sites[i]) for i in 1:N]) || 
            error("ranks are not consistent with dimension of sites")
    end
    
    linkdims_ = [1, linkdims(tt)..., 1]
    links = [Index(linkdims_[l + 1], "link, l=$l") for l in 0:N]
    tensors_ = [ITensor(deepcopy(tt[n]), links[n], sites[n], links[n+1]) for n in 1:N]
    return MPS(tensors_)
end


# TensorTrainFit struct
struct TensorTrainFit{ValueType}
    tt::Vector{Array{ValueType, 3}}
    offsets::Vector{Int}
end

# Compute offsets, used to map the 1D parameter vector back to core tensors
function TensorTrainFit{ValueType}(tt) where {ValueType}
    offsets = [0]
    for n in 1:length(tt)
        push!(offsets, offsets[end] + length(tt[n]))
    end
    return TensorTrainFit{ValueType}(tt, offsets)
end


function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType}
    tensors = to_tensors(obj, x)                          # reshape the 1D parameter vector into core tensors
    tensors2 = ITensors.MPS(tensors, sites)               # convert to ITensors.MPS
    tensors3 = ITensors.truncate(tensors2; cutoff=1e-10)  # truncate the bond dimensions
    return inner(tensors3, tensors3)                      # the loss: squared norm of the truncated MPS
    # return inner(tensors2, tensors2)  # this works if the two lines above are skipped
end

# Reshape the 1D parameter vector into the core tensors of the MPS
function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where {ValueType}
    return [
        reshape(
            x[obj.offsets[n]+1:obj.offsets[n+1]],
            size(obj.tt[n])
        )
        for n in 1:length(obj.tt)
    ]
end

# Flatten the core tensors into a single 1D array
flatten_(obj::Vector{Array{ValueType, 3}}) where {ValueType} = vcat(vec.(obj)...)

tt0 = make_random_mps([2 for i in 1:R], [D for i in 1:R-1], R)
ttfit = TensorTrainFit{Float64}(tt0) # Create TensorTrainFit object
x0::Vector{Float64} = flatten_(tt0) # flatten to a 1D parameter vector

function loss(x) 
    ttfit(x)
end 

optimizer = LBFGS(; maxiter=fit_itr, verbosity=1)
function loss_and_grad(x)
    y, (∇,) = withgradient(loss, x)
    return y, ∇
end

withgradient(loss, x0) # fails with "no method matching iterate(::Nothing)"
xopt, fs, gs, niter, normgradhistory = optimize(loss_and_grad, x0, optimizer)

The versions of the packages I use are the following.

ITensors v0.3.68
OptimKit v0.3.1
Zygote v0.6.70

Thank you for your time.

Best,
Rihito

Unfortunately that won’t work right now because truncate uses matrix factorizations like QR and SVD, and we haven’t defined derivative rules for those operations for ITensors yet. You could try writing your own AD rules.

I’m not sure where the error you are quoting is coming from; that one we may be able to circumvent in a simpler way. But ultimately, even if we do, you will hit issues related to not having derivative rules for matrix factorizations. Could you post the entire stack trace as a reference?

Thank you for your response and for providing an answer to this issue.
For your reference, this is the stack trace:

Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:95
  [2] chain_rrule_kw
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:235 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{allow_alias::Bool}, ::typeof(permute), ::ITensor, ::Index{Int64}, ::Index{Int64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87
  [5] #qx#333
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:511 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::ITensors.var"##qx#333", ::TagSet, ::Bool, ::typeof(ITensors.qx), ::typeof(qr), ::ITensor, ::Vector{Index{Int64}}, ::Vector{Index{Int64}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [7] qx
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:488 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{tags::TagSet, positive::Bool}, ::typeof(ITensors.qx), ::typeof(qr), ::ITensor, ::Vector{Index{Int64}}, ::Vector{Index{Int64}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [9] #qr#329
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:474 [inlined]
 [10] qr
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:473 [inlined]
 [11] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{tags::TagSet, positive::Bool}, ::typeof(qr), ::ITensor, ::Vector{Index{Int64}}, ::Vector{Index{Int64}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [12] #qr#321
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:458 [inlined]
 [13] _pullback(::Zygote.Context{false}, ::ITensors.var"##qr#321", ::@Kwargs{tags::TagSet, positive::Bool}, ::typeof(qr), ::ITensor, ::Vector{Index{Int64}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [14] qr
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:458 [inlined]
 [15] _apply
    @ ./boot.jl:838 [inlined]
 [16] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [17] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [18] #factorize_qr#336
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:566 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::ITensors.var"##factorize_qr#336", ::String, ::TagSet, ::Bool, ::typeof(ITensors.factorize_qr), ::ITensor, ::Vector{Index{Int64}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [21] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [22] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [23] factorize_qr
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:564 [inlined]
 [24] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{ortho::String, tags::TagSet}, ::typeof(ITensors.factorize_qr), ::ITensor, ::Vector{Index{Int64}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [25] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [26] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [27] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [28] #factorize#340
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:841 [inlined]
 [29] _pullback(::Zygote.Context{false}, ::ITensors.var"##factorize#340", ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::TagSet, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::typeof(factorize), ::ITensor, ::Vector{Index{Int64}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [31] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [32] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [33] factorize
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:746 [inlined]
 [34] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{tags::TagSet, maxdim::Nothing}, ::typeof(factorize), ::ITensor, ::Vector{Index{Int64}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [35] #orthogonalize!#336
    @ ~/.julia/packages/ITensors/4M2ep/src/ITensorMPS/abstractmps.jl:1610 [inlined]
 [36] _pullback(::Zygote.Context{false}, ::ITensors.ITensorMPS.var"##orthogonalize!#336", ::Nothing, ::Nothing, ::typeof(orthogonalize!), ::MPS, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [37] orthogonalize!
    @ ~/.julia/packages/ITensors/4M2ep/src/ITensorMPS/abstractmps.jl:1593 [inlined]
 [38] #truncate!#340
    @ ~/.julia/packages/ITensors/4M2ep/src/ITensorMPS/abstractmps.jl:1677 [inlined]
 [39] _pullback(::Zygote.Context{false}, ::ITensors.ITensorMPS.var"##truncate!#340", ::UnitRange{Int64}, ::@Kwargs{cutoff::Float64}, ::typeof(truncate!), ::NDTensors.BackendSelection.Algorithm{:frobenius, @NamedTuple{}}, ::MPS)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [40] truncate!
    @ ~/.julia/packages/ITensors/4M2ep/src/ITensorMPS/abstractmps.jl:1670 [inlined]
 [41] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{cutoff::Float64}, ::typeof(truncate!), ::NDTensors.BackendSelection.Algorithm{:frobenius, @NamedTuple{}}, ::MPS)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [42] #truncate!#339
    @ ~/.julia/packages/ITensors/4M2ep/src/ITensorMPS/abstractmps.jl:1667 [inlined]
 [43] _pullback(::Zygote.Context{false}, ::ITensors.ITensorMPS.var"##truncate!#339", ::String, ::@Kwargs{cutoff::Float64}, ::typeof(truncate!), ::MPS)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [44] truncate!
    @ ~/.julia/packages/ITensors/4M2ep/src/ITensorMPS/abstractmps.jl:1666 [inlined]
 [45] #truncate#341
    @ ~/.julia/packages/ITensors/4M2ep/src/ITensorMPS/abstractmps.jl:1693 [inlined]
 [46] _pullback(::Zygote.Context{false}, ::ITensors.ITensorMPS.var"##truncate#341", ::@Kwargs{cutoff::Float64}, ::typeof(truncate), ::MPS)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [47] truncate
    @ ~/.julia/packages/ITensors/4M2ep/src/ITensorMPS/abstractmps.jl:1691 [inlined]
 [48] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{cutoff::Float64}, ::typeof(truncate), ::MPS)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [49] TensorTrainFit
    @ ~/Documents/Research/post-doc/todo-group/QuanticsOperatorLearning/notebook/fitting_MPS/to_shinaoka/MPS_with_autodiff_truncate_to_ItensorForum.ipynb:63 [inlined]
 [50] _pullback(ctx::Zygote.Context{false}, f::TensorTrainFit{Float64}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [51] loss
    @ ~/Documents/Research/post-doc/todo-group/QuanticsOperatorLearning/notebook/fitting_MPS/to_shinaoka/MPS_with_autodiff_truncate_to_ItensorForum.ipynb:87 [inlined]
 [52] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [53] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
 [54] pullback
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
 [55] withgradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:205

Thanks. From the stack trace you can see that Zygote is indeed failing at some point inside of the QR decomposition:

#...
 [14] qr
    @ ~/.julia/packages/ITensors/4M2ep/src/tensor_operations/matrix_decomposition.jl:458 [inlined]
#...

Note that we definitely have plans to add AD support for matrix factorizations. However, we are in the process of rewriting the entire NDTensors.jl backend library where those operations are defined ([NDTensors] Roadmap to removing `TensorStorage` types · Issue #1250 · ITensor/ITensors.jl · GitHub), and it will be much easier to define AD rules once that rewrite is complete, so we don’t plan to add AD rules for matrix factorizations until then. You are of course free to write them yourself in the meantime (we would recommend writing AD rules using ChainRules.jl). We don’t have a specific timeline for when the NDTensors.jl rewrite will be complete and when the AD rules will be written, but I would hope by the end of 2024. Being able to differentiate through matrix factorizations is the main limitation right now preventing users from differentiating arbitrary ITensor code using AD, and supporting it would unlock a lot of new applications, so we’re well aware it is a high priority thing to support.
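To make that concrete, here is a minimal sketch of what such a rule could look like for a thin QR factorization of a real, full-column-rank matrix, written with ChainRulesCore.jl (Zygote picks up ChainRules rules automatically). Note that my_qr and copyltu are hypothetical helper names for illustration, this operates on plain Julia matrices rather than ITensors, and the pullback is the standard QR pullback formula:

using LinearAlgebra
using ChainRulesCore

# Mirror the strict lower triangle onto the upper triangle (real case), keeping the diagonal.
copyltu(M) = tril(M) + transpose(tril(M, -1))

# Hypothetical wrapper returning the thin QR factors as plain matrices.
my_qr(A::AbstractMatrix{<:Real}) = (F = qr(A); (Matrix(F.Q), F.R))

function ChainRulesCore.rrule(::typeof(my_qr), A::AbstractMatrix{<:Real})
    Q, R = my_qr(A)
    function my_qr_pullback(ȳ)
        Q̄, R̄ = unthunk(ȳ[1]), unthunk(ȳ[2])
        Q̄ = Q̄ isa AbstractZero ? zero(Q) : Q̄
        R̄ = R̄ isa AbstractZero ? zero(R) : R̄
        # Standard QR pullback; valid when R is invertible (full column rank).
        M = R * R̄' - Q̄' * Q
        Ā = (Q̄ + Q * copyltu(M)) / R'
        return NoTangent(), Ā
    end
    return (Q, R), my_qr_pullback
end

# Usage: Zygote finds the rule automatically.
# using Zygote
# A = randn(4, 3)
# Zygote.gradient(A -> sum(abs2, my_qr(A)[2]), A)[1]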

Something else to keep in mind is that you may be able to write higher-level AD rules that avoid differentiating through the matrix factorizations altogether, if you use some higher-level knowledge of the structure of your program and the derivatives of certain steps of your program.
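As a toy illustration of that idea (a sketch of my own, not ITensor code): the nuclear norm of a matrix is computed through an SVD internally, but its derivative is known in closed form, so a single high-level rule avoids differentiating through the factorization entirely. Assuming a full-rank matrix with distinct singular values:

using LinearAlgebra
using ChainRulesCore

# Sum of singular values, computed via an SVD internally.
nuclearnorm(A::AbstractMatrix{<:Real}) = sum(svdvals(A))

# High-level rule: d(sum of singular values)/dA = U * Vt,
# so the pullback never differentiates through the SVD itself.
function ChainRulesCore.rrule(::typeof(nuclearnorm), A::AbstractMatrix{<:Real})
    F = svd(A)
    nuclearnorm_pullback(ȳ) = (NoTangent(), ȳ * (F.U * F.Vt))
    return sum(F.S), nuclearnorm_pullback
end

# using Zygote
# Zygote.gradient(nuclearnorm, randn(3, 3))[1]  # equals U * Vt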


Dear Matt Fishman,

Thank you for sharing the detailed plan for implementing AD. I look forward to the completion of the rewrite and to being able to differentiate through matrix factorizations!

In the meantime, I will consider using ChainRules.jl. If you happen to know of any Julia packages that have implemented AD for matrix factorizations, could you please let me know? In my case, support for real numbers would be sufficient.

Best regards,
Rihito Sakurai


You can see derivative rules for LinearAlgebra.svd (real only) and LinearAlgebra.eigen (non-Hermitian and Hermitian, as well as real and complex) here: ChainRules.jl/src/rulesets/LinearAlgebra/factorization.jl at v1.68.0 · JuliaDiff/ChainRules.jl · GitHub and here: ChainRules.jl/src/rulesets/LinearAlgebra/symmetric.jl at v1.68.0 · JuliaDiff/ChainRules.jl · GitHub.

There are open pull requests for LinearAlgebra.qr here: Implement QR pullback by Kolaru · Pull Request #306 · JuliaDiff/ChainRules.jl · GitHub and here: Support for qr decomposition pullback by rkube · Pull Request #469 · JuliaDiff/ChainRules.jl · GitHub.

There are some derivative rules for more matrix factorizations not defined in ChainRules here: GitHub - GiggleLiu/BackwardsLinalg.jl: Auto differentiation over linear algebras (a Zygote extension). They are based on this blog post: Linear Algebra Autodiff (complex valued) by Jinguo Liu. Those are written using the older ZygoteRules system but would be easy to translate to ChainRules. See also: List of papers, PDFs etc. with rules · Issue #117 · JuliaDiff/ChainRules.jl · GitHub for references on derivative rules of matrix factorizations and other operations.

It would be great if the rules in BackwardsLinalg.jl for QR and complex SVD were contributed to ChainRules.jl. That would be a good first step toward getting derivative rules for those operations working in ITensor.

Additionally, even if those rules are all defined, there may be some issues that pop up differentiating through the NDTensors.jl and ITensors.jl wrappers for matrix factorizations like svd, qr, and eigen, since they additionally convert the ITensors to Julia matrices and then rewrap those Julia matrices back into ITensors, and Zygote.jl would need to differentiate through all of those operations. Any issues related to that should be relatively easy to fix, but that’s something to keep in mind for anyone looking for a fun project. There may also be additional complexities involved in differentiating through truncated SVD and eigendecompositions.

All of this should be relatively easy to implement for dense ITensor operations. Block sparse operations will be more complicated, and that is the main reason why we haven’t implemented these rules yet, since we try to support dense and block sparse operations equally, and as I said the NDTensors.jl backend is being totally rewritten in a way that will make it simpler, and therefore easier, to define matrix factorization AD rules for block sparse operations. It may not even require specialized code for block sparse operations if the AD rules are written generically enough, but we would have to see about that.

Finally, it may be worth trying out Enzyme.jl, an alternative AD system to Zygote.jl that is meant to cover more of the language, for example by working on in-place operations. I haven’t tried using it to differentiate ITensor operations, so I don’t know how well it works for that. Additionally, I see that it has limited support for differentiating matrix factorizations.
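For reference, the basic calling convention looks like this on a plain Julia function (a toy example of mine, untested on ITensor code, using Enzyme's documented autodiff API):

using Enzyme

# Toy scalar-valued function of a vector.
f(x) = sum(abs2, x)

x = randn(5)
dx = zero(x)  # shadow vector: the gradient is accumulated into dx

# Reverse-mode AD; after this call, dx == 2 .* x.
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))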


Dear Matthew,

Thank you very much for your kind and detailed explanation regarding my question. I now have a better understanding of why our code did not work. I will look through the useful links you provided. Thanks again for that.

Best regards,
Rihito
