Skip to content

Commit ff2fcc1

Browse files
chunk by partition indexes (#134)
* chunk by partitions * references
1 parent 6a86d23 commit ff2fcc1

File tree

2 files changed

+73
-10
lines changed

2 files changed

+73
-10
lines changed

src/utils.jl

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io,
6262
6363
Unroll the given `xs` into an array of arrays along the given dimension `dims`.
6464
65-
See also [`stack`](@ref) and [`unbatch`](@ref).
65+
See also [`stack`](@ref), [`unbatch`](@ref),
66+
and [`chunk`](@ref).
6667
6768
# Examples
6869
@@ -156,6 +157,46 @@ function chunk(x::AbstractArray; size, dims::Int=ndims(x))
156157
return [_selectdim(x, dims, i) for i in idxs]
157158
end
158159

160+
161+
"""
162+
chunk(x, partition_idxs; [npartitions, dims])
163+
164+
Partition the array `x` along the dimension `dims` according to the indexes
165+
in `partition_idxs`: the `i`-th slice along `dims` goes to partition `partition_idxs[i]`.
166+
167+
`partition_idxs` must be sorted and contain only integers
168+
between 1 and the number of partitions.
169+
170+
If the number of partitions `npartitions` is not provided,
171+
it is inferred as `maximum(partition_idxs)`.
172+
173+
If `dims` is not provided, it defaults to the last dimension.
174+
175+
See also [`unbatch`](@ref).
176+
177+
# Examples
178+
179+
```jldoctest
180+
julia> x = reshape([1:10;], 2, 5)
181+
2×5 Matrix{Int64}:
182+
1 3 5 7 9
183+
2 4 6 8 10
184+
185+
julia> chunk(x, [1, 2, 2, 3, 3])
186+
3-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}:
187+
[1; 2;;]
188+
[3 5; 4 6]
189+
[7 9; 8 10]
190+
```
191+
"""
192+
function chunk(x::AbstractArray{T,N}, partition_idxs::AbstractVector;
193+
npartitions=nothing, dims=ndims(x)) where {T, N}
194+
@assert issorted(partition_idxs) "partition_idxs must be sorted"
195+
m = npartitions === nothing ? maximum(partition_idxs) : npartitions
196+
degrees = NNlib.scatter(+, ones_like(partition_idxs), partition_idxs, dstsize=(m,))
197+
return chunk(x; size=degrees, dims)
198+
end
199+
159200
# work around https://github.com/JuliaML/MLUtils.jl/issues/103
160201
_selectdim(x::AbstractArray, dims::Int, i) = selectdim(x, dims, i)
161202
_selectdim(x::AbstractArray, dims::Int, i::UnitRange) = _selectdim(x, Val(dims), i)
@@ -349,13 +390,13 @@ end
349390
Reverse of the [`batch`](@ref) operation,
350391
unstacking the last dimension of the array `x`.
351392
352-
See also [`unstack`](@ref).
393+
See also [`unstack`](@ref) and [`chunk`](@ref).
353394
354395
# Examples
355396
356397
```jldoctest
357398
julia> unbatch([1 3 5 7;
358-
2 4 6 8])
399+
2 4 6 8])
359400
4-element Vector{Vector{Int64}}:
360401
[1, 2]
361402
[3, 4]

test/utils.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,16 @@ end
134134
idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims)
135135
test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false)
136136

137+
138+
if CUDA.functional()
139+
# https://github.com/JuliaML/MLUtils.jl/issues/103
140+
x = rand(2, 10) |> cu
141+
cs = chunk(x, 2)
142+
@test length(cs) == 2
143+
@test cs[1] isa CuArray
144+
@test cs[1] == x[:, 1:5]
145+
end
146+
137147
@testset "size collection" begin
138148
a = reshape(collect(1:10), (5, 2))
139149
y = chunk(a; dims = 1, size = (1, 4))
@@ -144,13 +154,25 @@ end
144154
test_zygote(x -> chunk(x; dims = 1, size = (1, 4)), a)
145155
end
146156

147-
if CUDA.functional()
148-
# https://github.com/JuliaML/MLUtils.jl/issues/103
149-
x = rand(2, 10) |> cu
150-
cs = chunk(x, 2)
151-
@test length(cs) == 2
152-
@test cs[1] isa CuArray
153-
@test cs[1] == x[:, 1:5]
157+
@testset "chunk by partition_idxs" begin
158+
x = reshape(collect(1:15), (3, 5))
159+
partition_idxs = [1,1,3,3,4]
160+
161+
y = chunk(x, partition_idxs)
162+
@test length(y) == 4
163+
@test y[1] == [1 4; 2 5; 3 6]
164+
@test size(y[2]) == (3, 0)
165+
@test y[3] == [7 10; 8 11; 9 12]
166+
@test y[4] == reshape([13, 14, 15], 3, 1)
167+
168+
y = chunk(x, partition_idxs; npartitions=5)
169+
@test length(y) == 5
170+
@test size(y[5]) == (3, 0)
171+
172+
y = chunk(x, [1,1,2]; dims=1)
173+
@test length(y) == 2
174+
@test y[1] == [1 4 7 10 13; 2 5 8 11 14]
175+
@test y[2] == [3 6 9 12 15]
154176
end
155177
end
156178

0 commit comments

Comments
 (0)