Explicit unbatching #160
Conversation
Codecov Report
```
@@            Coverage Diff             @@
##           master     #160      +/-   ##
==========================================
- Coverage   86.26%   85.40%   -0.87%
==========================================
  Files          14       14
  Lines        1245     1281      +36
==========================================
+ Hits         1074     1094      +20
- Misses        171      187      +16
```
Continue to review full report at Codecov.
---
I think this can be filed as a PR upstream in MLUtils, where `Flux.unbatch` now lives. It could be something like:

```julia
function MLUtils.unbatch(x, idxs::Vector{Int})
    n = maximum(idxs)                  # number of groups to split into
    @assert length(idxs) == numobs(x)  # one group index per observation
    return [getobs(x, idxs .== i) for i in 1:n]
end
```

---
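As a standalone sketch of this mask-based approach (using the illustrative name `unbatch_by_mask` rather than pirating `MLUtils.unbatch`; the helper and sample data are assumptions for demonstration):

```julia
using MLUtils

# Illustrative mask-based unbatching: one boolean mask per group.
# `idxs[j]` gives the group of observation j; groups are numbered 1..n.
function unbatch_by_mask(x, idxs::Vector{Int})
    n = maximum(idxs)
    @assert length(idxs) == numobs(x)
    return [getobs(x, idxs .== i) for i in 1:n]
end

x = rand(2, 6)
idxs = [1, 1, 2, 2, 2, 3]
parts = unbatch_by_mask(x, idxs)
@assert length(parts) == 3       # three groups recovered
@assert size(parts[2]) == (2, 3) # group 2 holds observations 3:5
```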
I experimented with this some more recently and found that this implementation scales badly, due to the use of a boolean vector for indexing. Using `startindex:stopindex` ranges instead, a 10x speedup can be achieved; however, I haven't found a way that's safe for Zygote. Do you have any suggestions?

```julia
function Flux.unbatch(g::GNNGraph, x)
    changes = g.graph_indicator[1:end-1] .!== g.graph_indicator[2:end]
    index = [0; Array(findall(changes)); length(g.graph_indicator)]
    return [x[index[i]+1:index[i+1]] for i in 1:g.num_graphs]
end
```

---
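The scaling gap between the two approaches can be sketched on plain vectors, independent of `GNNGraph` (the names `mask_unbatch` and `range_unbatch` are illustrative, and absolute timings are machine-dependent):

```julia
# 1000 groups of 100 elements each, already sorted by group.
indicator = reduce(vcat, [fill(i, 100) for i in 1:1000])
x = rand(length(indicator))

# Boolean-mask version: builds an O(length) mask for every group.
mask_unbatch(x, ind, n) = [x[ind .== i] for i in 1:n]

# Range version: finds all group boundaries once, then slices.
function range_unbatch(x, ind, n)
    changes = ind[1:end-1] .!= ind[2:end]
    c = [0; findall(changes); length(ind)]
    return [x[c[i]+1:c[i+1]] for i in 1:n]
end

@assert mask_unbatch(x, indicator, 1000) == range_unbatch(x, indicator, 1000)
@time mask_unbatch(x, indicator, 1000);   # O(n·length) comparisons in total
@time range_unbatch(x, indicator, 1000);  # O(length) boundary scan, then views
```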
What's your use case for this? Since batching is faster than unbatching, I think the most efficient setting is to have a collection of single graphs and batch them in each mini-batch iteration. I also reorganized MLDatasets.jl along this design.

---
I use it for splitting the node data after running a batched graph through the GNN. In more detail, I'm working on solving combinatorial problems using AlphaZero.jl combined with this package, which is why I care a lot about speed.

---
I see. And you need it to be AD-friendly?

---
Yes, as the node data is used to optimize the GNN. AlphaZero basically uses supervised learning to optimize the output of the (G)NN towards that of an iterative search algorithm guided by the (G)NN.

---
Ideally one would like the entire pipeline to work on batched data, but if that's not possible, your solution seems fine and AD-friendly:

```julia
using Flux, MLUtils

function _unbatch(x::AbstractArray, indicator::AbstractVector{<:Integer}, n=maximum(indicator))
    @assert minimum(indicator) >= 1
    all_idxs = partion_indexes_sorted(n, indicator)
    return [getobs(x, idxs) for idxs in all_idxs]
end

function partion_indexes_sorted(n, indicator)
    @assert issorted(indicator)
    changes = indicator[1:end-1] .!== indicator[2:end]
    c = [0; findall(changes); length(indicator)]
    return [c[i]+1:c[i+1] for i in 1:n]
end

x = rand(2, 25)
indicator = [fill(1, 10); fill(2, 5); fill(3, 10)]
_unbatch(x, indicator)

grad = gradient(x) do x
    y = _unbatch(x, indicator)
    sum(y[1])
end[1]
```

You can also add `ChainRulesCore.@non_differentiable partion_indexes_sorted(n, indicator)` to facilitate Zygote in its job.

---
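The `@non_differentiable` hint mentioned above could be wired up like this (a sketch; it assumes `partion_indexes_sorted` from the snippet above is in scope):

```julia
using ChainRulesCore

# Mark the index computation as constant for AD: Zygote then skips
# tracing through `issorted`/`findall` and only differentiates the
# `getobs` indexing itself, which has a cheap, well-defined adjoint.
ChainRulesCore.@non_differentiable partion_indexes_sorted(n, indicator)
```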
Finally figured out why the gradient failed on my end: `changes = indicator[1:end-1] .!== indicator[2:end]` should be `changes = indicator[1:end-1] .!= indicator[2:end]`.

---
I've implemented the change. For problems where unbatching is necessary (for example in AlphaZero, where inference is batched and the results have to be distributed afterwards), it is in my opinion a nice addition:

```julia
using Flux, GraphNeuralNetworks

function full(batched)
    unbatched = Flux.unbatch(batched)
    return [g.ndata.x for g in unbatched]
end

function Flux.unbatch(g::GNNGraph, x)
    changes = g.graph_indicator[1:end-1] .!= g.graph_indicator[2:end]
    indexes = [0; Array(findall(changes)); length(g.graph_indicator)]
    return [x[indexes[i]+1:indexes[i+1]] for i in 1:g.num_graphs]
end

g() = (n = rand(5:10); rand_graph(n, 8, ndata=rand(1, n)))
g_b = Flux.batch([g() for _ in 1:1000])

full(g_b)
Flux.unbatch(g_b, g_b.ndata.x)

@time full(g_b);                       # 0.068085 seconds (42.01 k allocations: 12.840 MiB)
@time Flux.unbatch(g_b, g_b.ndata.x); # 0.000195 seconds (1.03 k allocations: 271.828 KiB)
```

---
Closing this in favor of the new |
Also the speedup to |
Adds the possibility of unbatching data only, following the notation of explicit input to layers.
4x speedup in comparison to full unbatching.
Let me know what you think before I add tests and documentation.
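The proposed explicit-input form could be exercised like this (a sketch assuming the `Flux.unbatch(g, x)` method added by this PR, with sample graphs as in the thread above):

```julia
using Flux, GraphNeuralNetworks

g() = (n = rand(5:10); rand_graph(n, 8, ndata=rand(1, n)))
g_b = Flux.batch([g() for _ in 1:100])

# Unbatch only the node data, without rebuilding 100 GNNGraphs.
xs = Flux.unbatch(g_b, g_b.ndata.x)
@assert length(xs) == g_b.num_graphs
```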