
Conversation

casper2002casper

Adds the possibility of unbatching only the data, following the notation of the layers' explicit-input form.

using Flux, GraphNeuralNetworks

# Baseline: unbatch the whole graph, then extract the node features.
function full(batched)
    unbatched = Flux.unbatch(batched)
    return [g.ndata.x for g in unbatched]
end

# Proposed: unbatch only the data, selecting the entries of each graph
# via the graph_indicator.
function Flux.unbatch(g::GNNGraph, x)
    return [x[g.graph_indicator .∈ Ref(i)] for i in 1:g.num_graphs]
end

g() = (n = rand(5:10); rand_graph(n, 8, ndata=rand(1,n)))
g_b = Flux.batch([g() for _ in 1:100])

full(g_b)
Flux.unbatch(g_b, g_b.ndata.x)
@time full(g_b);
@time Flux.unbatch(g_b, g_b.ndata.x);
0.001567 seconds (4.20 k allocations: 1.114 MiB)
0.000362 seconds (705 allocations: 457.891 KiB)

A roughly 4x speedup compared to full unbatching.

Let me know what you think before I add tests and documentation.

@codecov

codecov bot commented Apr 19, 2022

Codecov Report

Merging #160 (0b9fbfe) into master (c7d0afe) will decrease coverage by 0.86%.
The diff coverage is 0.00%.

❗ Current head 0b9fbfe differs from the pull request's most recent head 61a9657. Consider uploading reports for commit 61a9657 to get more accurate results.

@@            Coverage Diff             @@
##           master     #160      +/-   ##
==========================================
- Coverage   86.26%   85.40%   -0.87%     
==========================================
  Files          14       14              
  Lines        1245     1281      +36     
==========================================
+ Hits         1074     1094      +20     
- Misses        171      187      +16     
Impacted Files               Coverage Δ
src/GNNGraphs/transform.jl   95.81% <0.00%> (-0.90%) ⬇️
src/layers/conv.jl           77.31% <0.00%> (-1.39%) ⬇️
src/GNNGraphs/utils.jl       84.28% <0.00%> (-0.44%) ⬇️
src/GNNGraphs/convert.jl     89.91% <0.00%> (-0.17%) ⬇️
src/GNNGraphs/query.jl       92.94% <0.00%> (-0.05%) ⬇️

Continue to review the full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update c7d0afe...61a9657.

@CarloLucibello
Member

I think this can be filed as a PR upstream in MLUtils, where Flux.unbatch now lives. It could be something like

function MLUtils.unbatch(x, idxs::Vector{Int})
    n = maximum(idxs)                  # number of groups in the indicator
    @assert length(idxs) == numobs(x)
    return [getobs(x, idxs .∈ Ref(i)) for i in 1:n]
end

@casper2002casper casper2002casper marked this pull request as draft April 19, 2022 14:40
@casper2002casper
Author

I was experimenting with this some more recently and found that this implementation scales badly, due to the use of a boolean vector for indexing. Alternatively, using a startindex:stopindex notation, a 10x speedup can be achieved; however, I haven't found a way that's safe for Zygote. Do you have any suggestions?
For example this ugly but fast method:

function Flux.unbatch(g::GNNGraph, x) 
    changes = g.graph_indicator[1:end-1] .!== g.graph_indicator[2:end] 
    index =  [0; Array(findall(changes)); length(g.graph_indicator)]
    return [x[index[i]+1:index[i+1]] for i in 1:g.num_graphs]
end

@CarloLucibello
Member

What's your use case for this? Since batching is faster than unbatching, I think the most efficient setting is to have a collection of single graphs and batch them in each mini-batch iteration, as sketched below. I also reorganized MLDatasets.jl along these lines.
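For illustration, a rough sketch of that pattern, reusing the toy graphs from above (the batch size of 16 and the training-step placeholder are arbitrary):

using Flux, GraphNeuralNetworks
using Random: shuffle

# Keep the dataset as a collection of single graphs...
make_graph() = (n = rand(5:10); rand_graph(n, 8, ndata=rand(1, n)))
graphs = [make_graph() for _ in 1:100]

# ...and build a fresh batched graph for every mini-batch iteration.
for idxs in Iterators.partition(shuffle(1:length(graphs)), 16)
    g_batch = Flux.batch(graphs[idxs])
    # forward pass / training step on g_batch goes here
end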

@casper2002casper
Author

casper2002casper commented May 23, 2022

I use it for splitting the node data after running a batched graph through the GNN. In more detail, I'm working on solving combinatorial problems using AlphaZero.jl combined with this package, which is why I care a lot about speed.

@CarloLucibello
Member

I see. And you need it to be AD-friendly?

@casper2002casper
Author

casper2002casper commented May 23, 2022

Yes, as the node data is used to optimize the GNN. AlphaZero basically uses supervised learning to push the output of the (G)NN towards that of an iterative search algorithm guided by the (G)NN.
The unbatching needs to be AD-friendly in order to propagate gradients from the per-graph node data back to the GNN, roughly as sketched below.
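As a rough sketch of that use case (the GNNChain below is only a placeholder model, and Flux.unbatch(g, x) refers to the method proposed in this PR):

using Flux, GraphNeuralNetworks

model = GNNChain(GCNConv(1 => 8, relu), GCNConv(8 => 1))    # placeholder GNN
g() = (n = rand(5:10); rand_graph(n, 8, ndata=rand(1, n)))
g_b = Flux.batch([g() for _ in 1:4])

grads = gradient(Flux.params(model)) do
    y = model(g_b, g_b.ndata.x)          # batched forward pass over all nodes
    per_graph = Flux.unbatch(g_b, y)     # split the node outputs per graph
    sum(sum, per_graph)                  # stand-in for the real per-graph loss
end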

@CarloLucibello
Member

Ideally one would like the entire pipeline to work on batched data, but if that's not possible, your solution seems fine and AD-friendly:

using Flux, MLUtils

# Split x into per-group slices along the observation dimension,
# according to a sorted group indicator.
function _unbatch(x::AbstractArray, indicator::AbstractVector{<:Int}, n=maximum(indicator))
    @assert minimum(indicator) >= 1
    all_idxs = partion_indexes_sorted(n, indicator)
    return [getobs(x, idxs) for idxs in all_idxs]
end

# Turn a sorted group indicator into one index range per group.
function partion_indexes_sorted(n, indicator)
    @assert issorted(indicator)
    changes = indicator[1:end-1] .!== indicator[2:end]
    c = [0; findall(changes); length(indicator)]
    return [c[i]+1:c[i+1] for i in 1:n]
end

x = rand(2, 25)
indicator = [fill(1, 10); fill(2, 5); fill(3, 10)]

_unbatch(x, indicator)

grad = gradient(x) do x
    y = _unbatch(x, indicator)
    sum(y[1])
end[1]

You can also add

ChainRulesCore.@non_differentiable partion_indexes_sorted(n, indicator)

to make Zygote's job easier: the index computation is then treated as non-differentiable and excluded from the gradient pass.
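Putting that together, a minimal sketch reusing the definitions above:

using ChainRulesCore

# The index computation is integer bookkeeping, so Zygote can skip it entirely.
ChainRulesCore.@non_differentiable partion_indexes_sorted(n, indicator)

grad = gradient(x) do x
    sum(sum, _unbatch(x, indicator))
end[1]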

@casper2002casper
Author

Finally figured out why the gradient failed on my end:

    changes = indicator[1:end-1] .!== indicator[2:end]

should be

    changes = indicator[1:end-1] .!= indicator[2:end]

@casper2002casper
Author

I've implemented the change. For problems where unbatching is necessary (for example in AlphaZero, where inference is batched and the results have to be distributed afterwards), it is in my opinion a nice addition.

using Flux, GraphNeuralNetworks

# Baseline: unbatch the whole graph, then extract the node features.
function full(batched)
    unbatched = Flux.unbatch(batched)
    return [g.ndata.x for g in unbatched]
end

# Proposed: split only the data, using index ranges derived from graph_indicator.
function Flux.unbatch(g::GNNGraph, x)
    changes = g.graph_indicator[1:end-1] .!= g.graph_indicator[2:end]
    indexes = [0; Array(findall(changes)); length(g.graph_indicator)]
    return [x[indexes[i]+1:indexes[i+1]] for i in 1:g.num_graphs]
end

g() = (n = rand(5:10); rand_graph(n, 8, ndata=rand(1,n)))
g_b = Flux.batch([g() for _ in 1:1000])

full(g_b)
Flux.unbatch(g_b, g_b.ndata.x)

@time full(g_b);
@time Flux.unbatch(g_b, g_b.ndata.x);
  0.068085 seconds (42.01 k allocations: 12.840 MiB)
  0.000195 seconds (1.03 k allocations: 271.828 KiB)

@CarloLucibello
Member

CarloLucibello commented Jan 4, 2023

Closing this in favor of the new chunk(x, partition_idxs; [npartitions, dims]) method added to MLUtils.jl in JuliaML/MLUtils.jl#134
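For reference, a small sketch of the data-only split with that method, assuming the signature quoted above and the same toy x and indicator as earlier:

using MLUtils

x = rand(2, 25)
indicator = [fill(1, 10); fill(2, 5); fill(3, 10)]

# One chunk per group of the partition indices, split along the observation dimension.
chunks = chunk(x, indicator; dims=2)   # 3-element vector: 2×10, 2×5, 2×10 slices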

@CarloLucibello
Member

Also, the speedup to unbatch in #248 is going to reduce the need for unbatching the data only.
