Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io,

Unroll the given `xs` into an array of arrays along the given dimension `dims`.

See also [`stack`](@ref) and [`unbatch`](@ref).
See also [`stack`](@ref), [`unbatch`](@ref),
and [`chunk`](@ref).

# Examples

Expand Down Expand Up @@ -156,6 +157,46 @@ function chunk(x::AbstractArray; size, dims::Int=ndims(x))
return [_selectdim(x, dims, i) for i in idxs]
end


"""
chunk(x, partition_idxs; [npartitions, dims])

Partition the array `x` along the dimension `dims` according to the indexes
in `partition_idxs`.

`partition_idxs` must be sorted and contain only positive integers
between 1 and the number of partitions.

If the number of partition `npartitions` is not provided,
it is inferred from `partition_idxs`.

If `dims` is not provided, it defaults to the last dimension.

See also [`unbatch`](@ref).

# Examples

```jldoctest
julia> x = reshape([1:10;], 2, 5)
2×5 Matrix{Int64}:
1 3 5 7 9
2 4 6 8 10

julia> chunk(x, [1, 2, 2, 3, 3])
3-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}:
[1; 2;;]
[3 5; 4 6]
[7 9; 8 10]
```
"""
function chunk(x::AbstractArray{T,N}, partition_idxs::AbstractVector;
npartitions=nothing, dims=ndims(x)) where {T, N}
@assert issorted(partition_idxs) "partition_idxs must be sorted"
m = npartitions === nothing ? maximum(partition_idxs) : npartitions
degrees = NNlib.scatter(+, ones_like(partition_idxs), partition_idxs, dstsize=(m,))
return chunk(x; size=degrees, dims)
end

# work around https://github.com/JuliaML/MLUtils.jl/issues/103
_selectdim(x::AbstractArray, dims::Int, i) = selectdim(x, dims, i)
_selectdim(x::AbstractArray, dims::Int, i::UnitRange) = _selectdim(x, Val(dims), i)
Expand Down Expand Up @@ -349,13 +390,13 @@ end
Reverse of the [`batch`](@ref) operation,
unstacking the last dimension of the array `x`.

See also [`unstack`](@ref).
See also [`unstack`](@ref) and [`chunk`](@ref).

# Examples

```jldoctest
julia> unbatch([1 3 5 7;
2 4 6 8])
2 4 6 8])
4-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
Expand Down
36 changes: 29 additions & 7 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ end
idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims)
test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false)


if CUDA.functional()
# https://github.com/JuliaML/MLUtils.jl/issues/103
x = rand(2, 10) |> cu
cs = chunk(x, 2)
@test length(cs) == 2
@test cs[1] isa CuArray
@test cs[1] == x[:, 1:5]
end

@testset "size collection" begin
a = reshape(collect(1:10), (5, 2))
y = chunk(a; dims = 1, size = (1, 4))
Expand All @@ -144,13 +154,25 @@ end
test_zygote(x -> chunk(x; dims = 1, size = (1, 4)), a)
end

if CUDA.functional()
# https://github.com/JuliaML/MLUtils.jl/issues/103
x = rand(2, 10) |> cu
cs = chunk(x, 2)
@test length(cs) == 2
@test cs[1] isa CuArray
@test cs[1] == x[:, 1:5]
@testset "chunk by partition_idxs" begin
x = reshape(collect(1:15), (3, 5))
partition_idxs = [1,1,3,3,4]

y = chunk(x, partition_idxs)
@test length(y) == 4
@test y[1] == [1 4; 2 5; 3 6]
@test size(y[2]) == (3, 0)
@test y[3] == [7 10; 8 11; 9 12]
@test y[4] == reshape([13, 14, 15], 3, 1)

y = chunk(x, partition_idxs; npartitions=5)
@test length(y) == 5
@test size(y[5]) == (3, 0)

y = chunk(x, [1,1,2]; dims=1)
@test length(y) == 2
@test y[1] == [1 4 7 10 13; 2 5 8 11 14]
@test y[2] == [3 6 9 12 15]
end
end

Expand Down