diff --git a/src/API/types.jl b/src/API/types.jl index 55a4211e..4f375045 100644 --- a/src/API/types.jl +++ b/src/API/types.jl @@ -51,7 +51,7 @@ baremodule pybuiltins end # Wrap """ - PyArray{T,N,M,L,R}(x; copy=true, array=true, buffer=true) + PyArray{T,N,F}(x; copy=true, array=true, buffer=true) Wrap the Python array `x` as a Julia `AbstractArray{T,N}`. @@ -62,31 +62,35 @@ If `copy=false` then the resulting array is guaranteed to directly wrap the data The type parameters are all optional, and are: - `T`: The element type. - `N`: The number of dimensions. -- `M`: True if the array is mutable. -- `L`: True if the array supports fast linear indexing. -- `R`: The element type of the underlying buffer. Often equal to `T`. -""" -struct PyArray{T,N,M,L,R} <: AbstractArray{T,N} - ptr::Ptr{R} # pointer to the data +- `F`: Tuple of symbols, including: + - `:mutable`: The array is mutable. + - `:linear`: Supports fast linear indexing. + - `:contiguous`: Data is F-contiguous. Implies `:linear`. + - `:copy`/`:nocopy`: Whether or not a copy of the data may be taken. +""" +struct PyArray{T,N,F} <: AbstractArray{T,N} + ptr::Ptr{Cvoid} # pointer to the data length::Int # length of the array size::NTuple{N,Int} # size of the array strides::NTuple{N,Int} # strides (in bytes) between elements py::Py # underlying python object handle::Py # the data in this array is valid as long as this handle is alive - function PyArray{T,N,M,L,R}( + function PyArray{T,N,F}( ::Val{:new}, - ptr::Ptr{R}, + ptr::Ptr, size::NTuple{N,Int}, strides::NTuple{N,Int}, py::Py, handle::Py, - ) where {T,N,M,L,R} + ) where {T,N,F} T isa Type || error("T must be a Type") N isa Int || error("N must be an Int") - M isa Bool || error("M must be a Bool") - L isa Bool || error("L must be a Bool") - R isa DataType || error("R must be a DataType") - new{T,N,M,L,R}(ptr, prod(size), size, strides, py, handle) + F isa Tuple{Vararg{Symbol}} || error("F must be a tuple of Symbols") + for flag in F + flag in (:mutable, :linear, :contiguous, :copy, :nocopy) || + error("invalid entry of F: $(repr(flag))") + end + new{T,N,F}(ptr, prod(size), size, strides, py, handle) end end diff --git a/src/Wrap/PyArray.jl b/src/Wrap/PyArray.jl index 7b1dc8d3..030de5fc 100644 --- a/src/Wrap/PyArray.jl +++ b/src/Wrap/PyArray.jl @@ -1,37 +1,6 @@ -struct UnsafePyObject - ptr::C.PyPtr -end - - ispy(::PyArray) = true Py(x::PyArray) = x.py -for N in (missing, 1, 2) - for M in (missing, true, false) - for L in (missing, true, false) - for R in (true, false) - name = Symbol( - "Py", - M === missing ? "" : M ? "Mutable" : "Immutable", - L === missing ? "" : L ? "Linear" : "Cartesian", - R ? "Raw" : "", - N === missing ? "Array" : N == 1 ? "Vector" : "Matrix", - ) - name == :PyArray && continue - vars = Any[ - :T, - N === missing ? :N : N, - M === missing ? :M : M, - L === missing ? :L : L, - R ? :T : :R, - ] - @eval const $name{$(unique([v for v in vars if v isa Symbol])...)} = PyArray{$(vars...)} - @eval export $name - end - end - end -end - (::Type{A})( x; array::Bool = true, @@ -116,38 +85,29 @@ function pyarray_make( ::Type{A}, x::Py, info::PyArraySource, - ::Type{PyArray{T0,N0,M0,L0,R0}} = Utils._type_lb(A), - ::Type{PyArray{T1,N1,M1,L1,R1}} = Utils._type_ub(A), -) where {A<:PyArray,T0,N0,M0,L0,R0,T1,N1,M1,L1,R1} - # R (buffer eltype) + ::Type{PyArray{T0,N0,F0}} = Utils._type_lb(A), + ::Type{PyArray{T1,N1,F1}} = Utils._type_ub(A), +) where {A<:PyArray,T0,N0,F0,T1,N1,F1} + # T (eltype) - checked against R (ptr eltype) R′ = pyarray_get_R(info)::DataType - if R0 == R1 - R = R1 - R == R′ || error("incorrect R, got $R, should be $R′") - # elseif T0 == T1 && T1 in (Bool, Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128, Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) - # R = T1 - # R == R′ || error("incorrect R, got $R, should be $R′") - # R <: R1 || error("R out of bounds, got $R, should be <: $R1") - # R >: R0 || error("R out of bounds, got $R, should be >: $R0") - else - R = R′ - end - # ptr - ptr = pyarray_get_ptr(info, R)::Ptr{R} - # T (eltype) if T0 == T1 T = T1 - pyarray_check_T(T, R) + R = pyarray_get_R(T) + R′ == R || error("eltype $T is incompatible with array data of type $R′") else - T = pyarray_get_T(R, T0, T1)::DataType - T <: T1 || error("T out of bounds, got $T, should be <: $T1") - T >: T0 || error("T out of bounds, got $T, should be >: $T0") + T = pyarray_get_T(R′, T0, T1)::DataType + T <: T1 || error("computed eltype $T out of bounds, should be <: $T1") + T >: T0 || error("computed eltype $T out of bounds, should be >: $T0") + R = pyarray_get_R(T) + R′ == R || error("computed eltype $T is incompatible with array data of type $R′") end + # ptr + ptr = pyarray_get_ptr(info)::Ptr{Cvoid} # N (ndims) N′ = pyarray_get_N(info)::Int if N0 == N1 N = N1 isa Int ? N1 : Int(N1) - N == N′ || error("incorrect N, got $N, should be $N′") + N == N′ || error("number of dimensions $N incorrect, expecting $N′") else N = N′ end @@ -155,28 +115,28 @@ function pyarray_make( size = pyarray_get_size(info, Val(N))::NTuple{N,Int} # strides strides = pyarray_get_strides(info, Val(N), R, size)::NTuple{N,Int} - # M (mutable) - # TODO: if M==false, we don't need to compute M′ - M′ = pyarray_get_M(info)::Bool - if M0 == M1 - M = M1 isa Bool ? M1 : Bool(M1) - M && !M′ && error("incorrect M, got $M, should be $M′") - else - M = M′ - end - # L (linearly indexable) - # TODO: if L==false, we don't need to compute L′ - L′ = N < 2 || strides == Utils.size_to_fstrides(strides[1], size) - if L0 == L1 - L = L1 isa Bool ? L1 : Bool(L1) - L && !L′ && error("incorrect L, got $L, should be $L′") + # F (flags) + M = pyarray_get_M(info)::Bool + L = N < 2 || strides == Utils.size_to_fstrides(strides[1], size) + C = L && (N == 0 || strides[1] == sizeof(R)) + Flist = Symbol[] + if F0 == F1 + F = F1::Tuple{Vararg{Symbol}} + (:mutable in F) && (!M) && error(":mutable flag given but array is not mutable") + (:linear in F) && (!L) && error(":linear flag given but array is not evenly spaced") + (:contiguous in F) && + (!C) && + error(":contiguous flag given but array is not contiguous") else - L = L′ + M && push!(Flist, :mutable) + L && push!(Flist, :linear) + C && push!(Flist, :contiguous) + F = Tuple(Flist) end # handle handle = pyarray_get_handle(info) # done - arr = PyArray{T,N,M,L,R}(Val(:new), ptr, size, strides, x, handle) + arr = PyArray{T,N,F}(Val(:new), ptr, size, strides, x, handle) return pyconvert_return(arr) end @@ -367,7 +327,7 @@ function pyarray_get_R(src::PyArraySource_ArrayInterface) return R end -pyarray_get_ptr(src::PyArraySource_ArrayInterface, ::Type{R}) where {R} = Ptr{R}(src.ptr) +pyarray_get_ptr(src::PyArraySource_ArrayInterface) = src.ptr pyarray_get_N(src::PyArraySource_ArrayInterface) = Int(@py jllen(@jl(src.dict)["shape"])) @@ -490,8 +450,8 @@ function pyarray_get_R(src::PyArraySource_ArrayStruct) @assert false end -function pyarray_get_ptr(src::PyArraySource_ArrayStruct, ::Type{R}) where {R} - return Ptr{R}(src.info.data) +function pyarray_get_ptr(src::PyArraySource_ArrayStruct) + src.info.data end function pyarray_get_N(src::PyArraySource_ArrayStruct) @@ -564,7 +524,7 @@ const PYARRAY_BUFFERFORMAT_TO_TYPE = let c = Utils.islittleendian() ? '<' : '>' "$(c)d" => Cdouble, "?" => Bool, "P" => Ptr{Cvoid}, - "O" => UnsafePyObject, + "O" => C.PyPtr, "=e" => Float16, "=f" => Float32, "=d" => Float64, @@ -582,7 +542,7 @@ function pyarray_get_R(src::PyArraySource_Buffer) return ptr == C_NULL ? UInt8 : pyarray_bufferformat_to_type(String(ptr)) end -pyarray_get_ptr(src::PyArraySource_Buffer, ::Type{R}) where {R} = Ptr{R}(src.buf.buf[!]) +pyarray_get_ptr(src::PyArraySource_Buffer) = src.buf.buf[!] pyarray_get_N(src::PyArraySource_Buffer) = Int(src.buf.ndim[]) @@ -620,93 +580,100 @@ Base.length(x::PyArray) = x.length Base.size(x::PyArray) = x.size -Utils.ismutablearray(x::PyArray{T,N,M,L,R}) where {T,N,M,L,R} = M +Utils.ismutablearray(x::PyArray{T,N,F}) where {T,N,F} = (:mutable in F) -Base.IndexStyle(::Type{PyArray{T,N,M,L,R}}) where {T,N,M,L,R} = - L ? Base.IndexLinear() : Base.IndexCartesian() +Base.IndexStyle(::Type{PyArray{T,N,F}}) where {T,N,F} = + (:linear in F) ? Base.IndexLinear() : Base.IndexCartesian() -Base.unsafe_convert(::Type{Ptr{T}}, x::PyArray{T,N,M,L,T}) where {T,N,M,L} = x.ptr +Base.unsafe_convert(::Type{Ptr{T}}, x::PyArray{T}) where {T} = + pyarray_get_R(T) == T ? Ptr{T}(x.ptr) : error("") -Base.elsize(::Type{PyArray{T,N,M,L,T}}) where {T,N,M,L} = sizeof(T) +Base.elsize(::Type{PyArray{T}}) where {T} = pyarray_get_R(T) == T ? sizeof(T) : error("") -Base.strides(x::PyArray{T,N,M,L,R}) where {T,N,M,L,R} = +function Base.strides(x::PyArray{T}) where {T} + R = pyarray_get_R(T) if all(mod.(x.strides, sizeof(R)) .== 0) div.(x.strides, sizeof(R)) else error("strides are not a multiple of element size") end - -function Base.showarg(io::IO, x::PyArray{T,N}, toplevel::Bool) where {T,N} - toplevel || print(io, "::") - print(io, "PyArray{") - show(io, T) - print(io, ", ", N, "}") - return end @propagate_inbounds Base.getindex(x::PyArray{T,N}, i::Vararg{Int,N}) where {T,N} = pyarray_getindex(x, i...) -@propagate_inbounds Base.getindex(x::PyArray{T,N,M,true}, i::Int) where {T,N,M} = - pyarray_getindex(x, i) -@propagate_inbounds Base.getindex(x::PyArray{T,1,M,true}, i::Int) where {T,M} = - pyarray_getindex(x, i) +@propagate_inbounds Base.getindex(x::PyArray, i::Int) = pyarray_getindex(x, i) -@propagate_inbounds function pyarray_getindex(x::PyArray, i...) +@propagate_inbounds function pyarray_getindex(x::PyArray{T}, i...) where {T} @boundscheck checkbounds(x, i...) - pyarray_load(eltype(x), x.ptr + pyarray_offset(x, i...)) + R = pyarray_get_R(T) + pyarray_load(T, Ptr{R}(x.ptr + pyarray_offset(x, i...))) end -@propagate_inbounds Base.setindex!(x::PyArray{T,N,true}, v, i::Vararg{Int,N}) where {T,N} = +@propagate_inbounds Base.setindex!(x::PyArray{T,N}, v, i::Vararg{Int,N}) where {T,N} = pyarray_setindex!(x, v, i...) -@propagate_inbounds Base.setindex!(x::PyArray{T,N,true,true}, v, i::Int) where {T,N} = - pyarray_setindex!(x, v, i) -@propagate_inbounds Base.setindex!(x::PyArray{T,1,true,true}, v, i::Int) where {T} = - pyarray_setindex!(x, v, i) +@propagate_inbounds Base.setindex!(x::PyArray, v, i::Int) = pyarray_setindex!(x, v, i) -@propagate_inbounds function pyarray_setindex!(x::PyArray{T,N,true}, v, i...) where {T,N} +@propagate_inbounds function pyarray_setindex!(x::PyArray{T,N,F}, v, i...) where {T,N,F} + (:mutable in F) || error("array is immutable") @boundscheck checkbounds(x, i...) - pyarray_store!(x.ptr + pyarray_offset(x, i...), convert(T, v)) + R = pyarray_get_R(T) + pyarray_store!(T, Ptr{R}(x.ptr + pyarray_offset(x, i...)), convert(T, v)::T) x end -pyarray_offset(x::PyArray{T,N,M,true}, i::Int) where {T,N,M} = - N == 0 ? 0 : (i - 1) * x.strides[1] -pyarray_offset(x::PyArray{T,1,M,true}, i::Int) where {T,M} = (i - 1) .* x.strides[1] -pyarray_offset(x::PyArray{T,N}, i::Vararg{Int,N}) where {T,N} = sum((i .- 1) .* x.strides) -pyarray_offset(x::PyArray{T,0}) where {T} = 0 +function pyarray_offset(x::PyArray{T,N,F}, i::Int) where {T,N,F} + if N == 0 + 0 + elseif (:contiguous in F) + (i - 1) * sizeof(pyarray_get_R(T)) + elseif (N == 1) || (:linear in F) + (i - 1) * x.strides[1] + else + # convert i to cartesian indices + # there's no public function for this :( + j = Base._to_subscript_indices(x, i) + sum((j .- 1) .* x.strides) + end +end + +function pyarray_offset(x::PyArray{T,N,F}, i::Vararg{Int,N}) where {T,N,F} + sum((i .- 1) .* x.strides) +end function pyarray_load(::Type{T}, p::Ptr{R}) where {T,R} if R == T unsafe_load(p) - elseif R == UnsafePyObject + elseif R == C.PyPtr u = unsafe_load(p) - o = u.ptr == C_NULL ? pynew(Py(nothing)) : pynew(incref(u.ptr)) + o = u == C_NULL ? pynew(Py(nothing)) : pynew(incref(u)) T == Py ? o : pyconvert(T, o) else convert(T, unsafe_load(p)) end end -function pyarray_store!(p::Ptr{R}, x::T) where {R,T} +function pyarray_store!(::Type{T}, p::Ptr{R}, x::T) where {R,T} if R == T unsafe_store!(p, x) - elseif R == UnsafePyObject + elseif R == C.PyPtr @autopy x begin decref(unsafe_load(p).ptr) - unsafe_store!(p, UnsafePyObject(getptr(incref(x_)))) + unsafe_store!(p, getptr(incref(x_))) end else unsafe_store!(p, convert(R, x)) end end -function pyarray_get_T(::Type{R}, ::Type{T0}, ::Type{T1}) where {R,T0,T1} - if R == UnsafePyObject +@generated function pyarray_get_T(::Type{R}, ::Type{T0}, ::Type{T1}) where {R,T0,T1} + if R == C.PyPtr if T0 <: Py <: T1 Py else - T1 + error("impossible") end + elseif R <: Tuple + error("not implemented") elseif T0 <: R <: T1 R else @@ -714,16 +681,38 @@ function pyarray_get_T(::Type{R}, ::Type{T0}, ::Type{T1}) where {R,T0,T1} end end -function pyarray_check_T(::Type{T}, ::Type{R}) where {T,R} - if R == UnsafePyObject - nothing - elseif T == R - nothing - elseif T <: Number && R <: Number - nothing - elseif T <: AbstractString && R <: AbstractString - nothing +@generated function pyarray_get_R(::Type{T}) where {T} + if ( + T == Bool || + T == Int8 || + T == Int16 || + T == Int32 || + T == Int64 || + T == UInt8 || + T == UInt16 || + T == UInt32 || + T == UInt64 || + T == Float16 || + T == Float32 || + T == Float64 || + T == Complex{Float16} || + T == Complex{Float32} || + T == Complex{Float16} + ) + T + elseif isconcretetype(T) && + isbitstype(T) && + ( + T <: NumpyDates.InlineDateTime64 || + T <: NumpyDates.InlineTimeDelta64 || + T <: Ptr + ) + T + elseif T == Py + C.PyPtr + elseif T <: Tuple + Tuple{map(pyarray_get_R, T.parameters)...} else - error("invalid eltype T=$T for raw eltype R=$R") + error("unsupported eltype $T") end end