@@ -62,7 +62,8 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io,
62
62
63
63
Unroll the given `xs` into an array of arrays along the given dimension `dims`.
64
64
65
- See also [`stack`](@ref) and [`unbatch`](@ref).
65
+ See also [`stack`](@ref), [`unbatch`](@ref),
66
+ and [`chunk`](@ref).
66
67
67
68
# Examples
68
69
@@ -156,6 +157,46 @@ function chunk(x::AbstractArray; size, dims::Int=ndims(x))
156
157
return [_selectdim (x, dims, i) for i in idxs]
157
158
end
158
159
160
+
161
+ """
162
+ chunk(x, partition_idxs; [npartitions, dims])
163
+
164
+ Partition the array `x` along the dimension `dims` according to the indexes
165
+ in `partition_idxs`.
166
+
167
+ `partition_idxs` must be sorted and contain only positive integers
168
+ between 1 and the number of partitions.
169
+
170
+ If the number of partition `npartitions` is not provided,
171
+ it is inferred from `partition_idxs`.
172
+
173
+ If `dims` is not provided, it defaults to the last dimension.
174
+
175
+ See also [`unbatch`](@ref).
176
+
177
+ # Examples
178
+
179
+ ```jldoctest
180
+ julia> x = reshape([1:10;], 2, 5)
181
+ 2×5 Matrix{Int64}:
182
+ 1 3 5 7 9
183
+ 2 4 6 8 10
184
+
185
+ julia> chunk(x, [1, 2, 2, 3, 3])
186
+ 3-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}:
187
+ [1; 2;;]
188
+ [3 5; 4 6]
189
+ [7 9; 8 10]
190
+ ```
191
+ """
192
+ function chunk (x:: AbstractArray{T,N} , partition_idxs:: AbstractVector ;
193
+ npartitions= nothing , dims= ndims (x)) where {T, N}
194
+ @assert issorted (partition_idxs) " partition_idxs must be sorted"
195
+ m = npartitions === nothing ? maximum (partition_idxs) : npartitions
196
+ degrees = NNlib. scatter (+ , ones_like (partition_idxs), partition_idxs, dstsize= (m,))
197
+ return chunk (x; size= degrees, dims)
198
+ end
199
+
159
200
# work around https://github.com/JuliaML/MLUtils.jl/issues/103
160
201
_selectdim (x:: AbstractArray , dims:: Int , i) = selectdim (x, dims, i)
161
202
_selectdim (x:: AbstractArray , dims:: Int , i:: UnitRange ) = _selectdim (x, Val (dims), i)
@@ -349,13 +390,13 @@ end
349
390
Reverse of the [`batch`](@ref) operation,
350
391
unstacking the last dimension of the array `x`.
351
392
352
- See also [`unstack`](@ref).
393
+ See also [`unstack`](@ref) and [`chunk`](@ref) .
353
394
354
395
# Examples
355
396
356
397
```jldoctest
357
398
julia> unbatch([1 3 5 7;
358
- 2 4 6 8])
399
+ 2 4 6 8])
359
400
4-element Vector{Vector{Int64}}:
360
401
[1, 2]
361
402
[3, 4]
0 commit comments