Skip to content

Commit 39058e6

Browse files
committed
move part of getindex computation to compilation time
Also add `laplacian` and `laplacian!`
1 parent e72cb9a commit 39058e6

File tree

3 files changed

+102
-90
lines changed

3 files changed

+102
-90
lines changed

src/ImageBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ export
1111
fdiff!,
1212
fdiv,
1313
fdiv!,
14+
flaplacian,
15+
flaplacian!,
1416
DiffView,
1517

1618
# basic image statistics, from Images.jl

src/diff.jl

Lines changed: 88 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -3,92 +3,88 @@ struct Periodic <: BoundaryCondition end
33
struct ZeroFill <: BoundaryCondition end
44

55
"""
6-
DiffView(A::AbstractArray, [rev=Val(false)], [bc::BoundaryCondition=Periodic()]; dims)
6+
DiffView(A::AbstractArray, dims::Val{D}, [bc::BoundaryCondition=Periodic()], [rev=Val(false)])
77
88
Lazy version of finite difference [`fdiff`](@ref).
99
1010
!!! tip
11-
For performance, `rev` should be stable type `Val(false)` or `Val(true)`.
11+
For performance, both `dims` and `rev` require `Val` types.
1212
1313
# Arguments
1414
15+
- `dims::Val{D}`
16+
Specify the dimension D that dinite difference is applied to.
1517
- `rev::Bool`
1618
If `rev==Val(true)`, then it computes the backward difference
1719
`(A[end]-A[1], A[1]-A[2], ..., A[end-1]-A[end])`.
1820
- `boundary::BoundaryCondition`
1921
By default it computes periodically in the boundary, i.e., `Periodic()`.
2022
In some cases, one can fill zero values with `ZeroFill()`.
2123
"""
22-
struct DiffView{T,N,AT<:AbstractArray,BC,REV} <: AbstractArray{T,N}
24+
struct DiffView{T,N,D,BC,REV,AT<:AbstractArray} <: AbstractArray{T,N}
2325
data::AT
24-
dims::Int
2526
end
2627
function DiffView(
2728
data::AbstractArray{T,N},
29+
::Val{D},
2830
bc::BoundaryCondition=Periodic(),
29-
rev::Union{Val, Bool}=Val(false);
30-
dims=_fdiff_default_dims(data)) where {T,N}
31-
isnothing(dims) && throw(UndefKeywordError(:dims))
32-
rev = to_static_bool(rev)
33-
DiffView{maybe_floattype(T),N,typeof(data),typeof(bc),typeof(rev)}(data, dims)
34-
end
35-
function DiffView(
36-
data::AbstractArray,
37-
rev::Union{Val, Bool},
38-
bc::BoundaryCondition = Periodic();
39-
kwargs...)
40-
DiffView(data, bc, rev; kwargs...)
41-
end
42-
43-
to_static_bool(x::Union{Val{true},Val{false}}) = x
44-
function to_static_bool(x::Bool)
45-
@warn "Please use `Val($x)` for performance"
46-
return Val(x)
31+
rev::Val = Val(false)
32+
) where {T,N,D}
33+
DiffView{maybe_floattype(T),N,D,typeof(bc),typeof(rev),typeof(data)}(data)
4734
end
4835

4936
Base.size(A::DiffView) = size(A.data)
5037
Base.axes(A::DiffView) = axes(A.data)
5138
Base.IndexStyle(::DiffView) = IndexCartesian()
5239

53-
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,AT,Periodic,Val{true}}, I::Vararg{Int, N}) where {T,N,AT}
40+
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,D,Periodic,Val{true}}, I::Vararg{Int, N}) where {T,N,D}
5441
data = A.data
55-
I_prev = map(ntuple(identity, N), I, axes(data)) do i, p, r
56-
i == A.dims || return p
57-
p == first(r) && return last(r)
58-
p - 1
59-
end
42+
r = axes(data, D)
43+
x = I[D]
44+
x_prev = first(r) == x ? last(r) : x - 1
45+
I_prev = update_tuple(I, x_prev, Val(D))
6046
return convert(T, data[I...]) - convert(T, data[I_prev...])
6147
end
62-
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,AT,Periodic,Val{false}}, I::Vararg{Int, N}) where {T,N,AT}
48+
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,D,Periodic,Val{false}}, I::Vararg{Int, N}) where {T,N,D}
6349
data = A.data
64-
I_next = map(ntuple(identity, N), I, axes(data)) do i, p, r
65-
i == A.dims || return p
66-
p == last(r) && return first(r)
67-
p + 1
68-
end
50+
r = axes(data, D)
51+
x = I[D]
52+
x_next = last(r) == x ? first(r) : x + 1
53+
I_next = update_tuple(I, x_next, Val(D))
6954
return convert(T, data[I_next...]) - convert(T, data[I...])
7055
end
71-
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,AT,ZeroFill,Val{false}}, I::Vararg{Int, N}) where {T,N,AT}
56+
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,D,ZeroFill,Val{false}}, I::Vararg{Int, N}) where {T,N,D}
7257
data = A.data
73-
I_next = I .+ ntuple(i->i==A.dims, N)
74-
if checkbounds(Bool, data, I_next...)
75-
vi = convert(T, data[I...]) # it requires the caller to pass @inbounds
76-
@inbounds convert(T, data[I_next...]) - vi
77-
else
58+
x = I[D]
59+
if last(axes(data, D)) == x
7860
zero(T)
61+
else
62+
I_next = update_tuple(I, x+1, Val(D))
63+
convert(T, data[I_next...]) - convert(T, data[I...])
7964
end
8065
end
81-
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,AT,ZeroFill,Val{true}}, I::Vararg{Int, N}) where {T,N,AT}
66+
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,D,ZeroFill,Val{true}}, I::Vararg{Int, N}) where {T,N,D}
8267
data = A.data
83-
I_prev = I .- ntuple(i->i==A.dims, N)
84-
if checkbounds(Bool, data, I_prev...)
85-
vi = convert(T, data[I...]) # it requires the caller to pass @inbounds
86-
@inbounds vi - convert(T, data[I_prev...])
87-
else
68+
x = I[D]
69+
if first(axes(data, D)) == x
8870
zero(T)
71+
else
72+
I_prev = update_tuple(I, x-1, Val(D))
73+
convert(T, data[I...]) - convert(T, data[I_prev...])
8974
end
9075
end
9176

77+
@generated function update_tuple(A::NTuple{N, T}, x::T, ::Val{i}) where {T, N, i}
78+
# This is equivalent to `ntuple(j->j==i ? x : A[j], N)` but is optimized by moving
79+
# the if branches to compilation time.
80+
ex = :()
81+
for j in Base.OneTo(N)
82+
new_x = i == j ? :(x) : :(A[$j])
83+
ex = :($ex..., $new_x)
84+
end
85+
return ex
86+
end
87+
9288
# TODO: add keyword `shrink` to give a consistant result on Base
9389
# when this is done, then we can propose this change to upstream Base
9490
"""
@@ -201,45 +197,66 @@ maybe_floattype(::Type{CT}) where CT<:Color = base_color_type(CT){maybe_floattyp
201197

202198

203199
"""
204-
fdiv(Vs::AbstractArray...; boundary=:periodic)
200+
fdiv(Vs::AbstractArray...)
205201
206202
Discrete divergence operator for vector field (V₁, V₂, ..., Vₙ).
207203
208-
# Example
209-
210-
Laplacian operator of array `A` is the divergence of its gradient vector field (∂₁A, ∂₂A, ..., ∂ₙA):
211-
212-
```jldoctest
213-
julia> using ImageFiltering, ImageBase
214-
215-
julia> X = Float32.(rand(1:9, 7, 7));
216-
217-
julia> laplacian(X) = fdiv(ntuple(i->DiffView(X, dims=i), ndims(X))...)
218-
laplacian (generic function with 1 method)
219-
220-
julia> laplacian(X) == imfilter(X, Kernel.Laplacian(), "circular")
221-
true
222-
```
223-
224204
See also [`fdiv!`](@ref) for the in-place version.
225205
"""
226-
function fdiv(V₁::AbstractArray, Vs...; kwargs...)
227-
fdiv!(similar(V₁, floattype(eltype(V₁))), V₁, Vs...; kwargs...)
228-
end
206+
fdiv(V₁::AbstractArray, Vs...) = fdiv!(similar(V₁, floattype(eltype(V₁))), V₁, Vs...)
229207

230208
"""
231209
fdiv!(dst::AbstractArray, Vs::AbstractArray...)
232210
233211
The in-place version of [`fdiv`](@ref).
234212
"""
235213
function fdiv!(dst::AbstractArray, Vs::AbstractArray...)
236-
= map(ntuple(identity, length(Vs)), Vs) do n, V
237-
DiffView(V, Val(true), dims=n)
238-
end
214+
# negative adjoint of gradient is equivalent to the reversed finite difference
215+
= fnegative_adjoint_gradient(Vs...)
239216
@inbounds for i in CartesianIndices(dst)
240-
dst[i] = sum(x->_inbound_getindex(x, i), ∇)
217+
dst[i] = heterogeneous_getindex_sum(i, ∇...)
241218
end
242219
return dst
243220
end
244221

245-
@inline _inbound_getindex(x, i) = @inbounds x[i]
222+
@generated function heterogeneous_getindex_sum(i, Vs::Vararg{<:AbstractArray, N}) where N
223+
# This method is equivalent to `sum(V->V[i], Vs)` but is optimized for heterogeneous arrays
224+
ex = :(zero(eltype(Vs[1])))
225+
for j in Base.OneTo(N)
226+
ex = :($ex + Vs[$j][i])
227+
end
228+
return ex
229+
end
230+
231+
"""
232+
flaplacian(X::AbstractArray)
233+
234+
The Laplacian operator ∇² is the divergence of the gradient operator.
235+
"""
236+
flaplacian(X::AbstractArray) = flaplacian!(similar(X, maybe_floattype(eltype(X))), X)
237+
238+
"""
239+
flaplacian!(dst::AbstractArray, X::AbstractArray)
240+
241+
The in-place version of the Laplacian operator [`laplacian`](@ref).
242+
"""
243+
flaplacian!(dst::AbstractArray, X::AbstractArray) = fdiv!(dst, fgradient(X)...)
244+
245+
# These two functions pass dimension information `Val(i)` to DiffView so that
246+
# we can move computations to compilation time.
247+
@generated function fgradient(X::AbstractArray{T, N}) where {T, N}
248+
ex = :()
249+
for i in Base.OneTo(N)
250+
new_x = :(DiffView(X, Val($i), Periodic(), Val(false)))
251+
ex = :($ex..., $new_x)
252+
end
253+
return ex
254+
end
255+
@generated function fnegative_adjoint_gradient(Vs::Vararg{<:AbstractArray, N}) where N
256+
ex = :()
257+
for i in Base.OneTo(N)
258+
new_x = :(DiffView(Vs[$i], Val($i), Periodic(), Val(true)))
259+
ex = :($ex..., $new_x)
260+
end
261+
return ex
262+
end

test/diff.jl

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -109,46 +109,39 @@
109109
end
110110

111111
@testset "DiffView" begin
112-
A = rand(6)
113-
@test DiffView(A) ===
114-
DiffView(A, Val(false), ImageBase.Periodic()) ===
115-
DiffView(A, ImageBase.Periodic(), Val(false)) ===
116-
@test_logs((:warn, "Please use `Val(false)` for performance"), DiffView(A, false))
117-
118112
for T in generate_test_types([N0f8, Float32], [Gray, RGB])
119113
A = rand(T, 6)
120-
Av = DiffView(A)
121-
@test Av == DiffView(A, ImageBase.Periodic(), Val(false))
114+
Av = DiffView(A, Val(1))
115+
@test Av == DiffView(A, Val(1), ImageBase.Periodic(), Val(false))
122116
@test eltype(Av) == floattype(T)
123117
@test axes(Av) == axes(A)
124118
@test Av == fdiff(A)
125-
@test DiffView(A, Val(true)) == fdiff(A; rev=true)
126-
@test DiffView(A, ImageBase.ZeroFill()) == fdiff(A; boundary=:zero)
127-
@test DiffView(A, ImageBase.ZeroFill(), Val(true)) == fdiff(A; boundary=:zero, rev=true)
119+
@test DiffView(A, Val(1), ImageBase.Periodic(), Val(true)) == fdiff(A; rev=true)
120+
@test DiffView(A, Val(1), ImageBase.ZeroFill()) == fdiff(A; boundary=:zero)
121+
@test DiffView(A, Val(1), ImageBase.ZeroFill(), Val(true)) == fdiff(A; boundary=:zero, rev=true)
128122

129123
A = rand(T, 6, 6)
130-
Av = DiffView(A, dims=1)
124+
Av = DiffView(A, Val(1))
131125
@test eltype(Av) == floattype(T)
132126
@test axes(Av) == axes(A)
133127
@test Av == fdiff(A, dims=1)
134-
@test DiffView(A, Val(true), dims=1) == fdiff(A; dims=1, rev=true)
135-
@test DiffView(A, ImageBase.ZeroFill(), dims=1) == fdiff(A; boundary=:zero, dims=1)
136-
@test DiffView(A, ImageBase.ZeroFill(), Val(true), dims=1) == fdiff(A; boundary=:zero, rev=true, dims=1)
128+
@test DiffView(A, Val(1), ImageBase.Periodic(), Val(true)) == fdiff(A; dims=1, rev=true)
129+
@test DiffView(A, Val(1), ImageBase.ZeroFill()) == fdiff(A; boundary=:zero, dims=1)
130+
@test DiffView(A, Val(1), ImageBase.ZeroFill(), Val(true)) == fdiff(A; boundary=:zero, rev=true, dims=1)
137131
end
138132

139133
A = OffsetArray(rand(6, 6), -1, -1)
140-
Av = DiffView(A, dims=1)
134+
Av = DiffView(A, Val(1))
141135
@test axes(Av) == axes(A)
142136
@test Av == fdiff(A, dims=1)
143137
end
144138

145-
@testset "fdiv" begin
146-
laplacian(X) = fdiv(ntuple(i->DiffView(X, dims=i), ndims(X))...)
139+
@testset "fdiv/flaplacian" begin
147140
ref_laplacian(X) = imfilter(X, Kernel.Laplacian(ntuple(x->true, ndims(X))), "circular")
148141
for T in generate_test_types([N0f8, Float32], [Gray, RGB])
149142
for sz in [(7,), (7, 7), (7, 7, 7)]
150143
A = rand(T, sz...)
151-
@test laplacian(A) ref_laplacian(A)
144+
@test flaplacian(A) ref_laplacian(A)
152145
end
153146
end
154147
end

0 commit comments

Comments
 (0)