Closed
Description
Hello!
I have code in which Zygote works fine when given an ordinary ArrayReg but fails with a BatchedArrayReg.
Here is the relevant part:
function loss(param)
    vqc_p = dispatch(vqc, param)
    bstate_p = apply(bstate, vqc_p)
    pr = probs(bstate_p)
    pr_s = sum(pr, dims = 2)
    r = -sum(pr_s[1:2^(q-1)])
    return r
end

grad = Zygote.gradient(loss, param)[1]
Here vqc is a variational circuit and bstate is an ArrayReg. If bstate is a BatchedArrayReg instead, the following error occurs:
ERROR: MethodError: no method matching BatchedArrayReg{2, ComplexF64, LinearAlgebra.Transpose{ComplexF64, Matrix{ComplexF64}}}(::Matrix{ComplexF64})
The type `BatchedArrayReg{2, ComplexF64, LinearAlgebra.Transpose{ComplexF64, Matrix{ComplexF64}}}` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
BatchedArrayReg{D, T, MT}(::MT, ::Int64) where {D, T, MT<:AbstractMatrix{T}}
@ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/RaQkU/src/register.jl:75
Stacktrace:
[1] tangent_to_reg(::Type{BatchedArrayReg{2, ComplexF64, LinearAlgebra.Transpose{…}}}, reg::ChainRulesCore.Tangent{Any, @NamedTuple{state::Matrix{…}, nbatch::ChainRulesCore.ZeroTangent}})
@ YaoBlocks.AD ~/.julia/packages/YaoBlocks/uf6fB/src/autodiff/chainrules_patch.jl:93
[2] (::YaoBlocks.AD.var"#47#48"{ChainBlock{2}, BatchedArrayReg{2, ComplexF64, LinearAlgebra.Transpose{…}}})(outδ::ChainRulesCore.Tangent{Any, @NamedTuple{state::Matrix{…}, nbatch::ChainRulesCore.ZeroTangent}})
@ YaoBlocks.AD ~/.julia/packages/YaoBlocks/uf6fB/src/autodiff/chainrules_patch.jl:82
[3] (::Zygote.ZBack{YaoBlocks.AD.var"#47#48"{ChainBlock{2}, BatchedArrayReg{2, ComplexF64, LinearAlgebra.Transpose{ComplexF64, Matrix{ComplexF64}}}}})(dy::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/chainrules.jl:222
[4] loss
@ /workspace/ml/julia/yao/lsh_train.jl:54 [inlined]
[5] (::Zygote.Pullback{Tuple{var"#loss#24"{Int64, ChainBlock{2}, Int64}, Vector{Float64}}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[6] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{var"#loss#24"{Int64, ChainBlock{2}, Int64}, Vector{Float64}}, Any}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
[7] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:154
[8] qml_zeroing_batched(bstate::BatchedArrayReg{2, ComplexF64, LinearAlgebra.Transpose{ComplexF64, Matrix{ComplexF64}}}, vqc::ChainBlock{2}, step::Int64; index::Int64, param::Type, dev::Symbol)
@ Main /workspace/ml/julia/yao/lsh_train.jl:100
[9] macro expansion
@ ./timing.jl:581 [inlined]
[10] top-level scope
@ ./REPL[51]:1
Some type information was truncated. Use `show(err)` to see complete types.
What could be the issue, and how can I fix it? Are some chain rules missing?
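Reading the trace literally, `tangent_to_reg` seems to call the fully parameterised type `BatchedArrayReg{2, ComplexF64, Transpose{…}}` with a single plain `Matrix`, while the only candidate listed expects `(::MT, ::Int64)` with `MT` being the `Transpose` storage type. The sketch below reproduces that kind of mismatch with a toy struct. `ToyReg` is a hypothetical stand-in invented for illustration, not the real YaoArrayRegister definition, and the final line only shows what a matching call would look like, not how the package should be patched.

```julia
using LinearAlgebra

# Hypothetical stand-in for a batched register; NOT the YaoArrayRegister type.
struct ToyReg{T, MT<:AbstractMatrix{T}}
    state::MT
    nbatch::Int
    # The only constructor: state must already have storage type MT,
    # and nbatch must be given (mirrors the "closest candidate" in the trace).
    ToyReg{T, MT}(state::MT, nbatch::Int) where {T, MT<:AbstractMatrix{T}} =
        new{T, MT}(state, nbatch)
end

m = ComplexF64[1 2; 3 4]

# A register whose storage is a lazy Transpose wrapper, as in the error message.
reg = ToyReg{ComplexF64, typeof(transpose(m))}(transpose(m), 2)
RT = typeof(reg)

# Rebuilding from a plain Matrix through the fully parameterised type fails,
# because no method accepts (::Matrix) for this combination of parameters.
failed = try
    RT(copy(m))
    false
catch err
    err isa MethodError
end

# Wrapping the data in the expected storage type and passing nbatch succeeds.
rebuilt = RT(transpose(permutedims(m)), 2)
```

If this reading is right, the failure would sit in how the tangent is converted back into a register for `Transpose`-backed storage rather than in the `loss` function itself, but that is an assumption based only on the stack trace above.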