Skip to content

Commit 8e803e1

Browse files
committed
[FFI][REFACTOR] Update to distinguish as and cast (apache#17979)
This PR updates the Any system to distinguish as and cast - as function will run strict check and won't do any type conversion - try_cast/cast will try to run the type conversion We also updated the type traits to be consistent with the naming
1 parent 0eeb661 commit 8e803e1

File tree

16 files changed

+318
-197
lines changed

16 files changed

+318
-197
lines changed

include/tvm/ffi/any.h

Lines changed: 111 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,41 @@ class AnyView {
9999
return *this;
100100
}
101101

102+
/*!
103+
* \brief Try to see if we can reinterpret the AnyView to as T object.
104+
*
105+
* \tparam T The type to cast to.
106+
* \return The casted value, or std::nullopt if the cast is not possible.
107+
* \note This function won't try run type conversion (use try_cast for that purpose).
108+
*/
102109
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
103110
TVM_FFI_INLINE std::optional<T> as() const {
104-
return TypeTraits<T>::TryConvertFromAnyView(&data_);
111+
if (TypeTraits<T>::CheckAnyStrict(&data_)) {
112+
return TypeTraits<T>::CopyFromAnyViewAfterCheck(&data_);
113+
} else {
114+
return std::optional<T>(std::nullopt);
115+
}
116+
}
117+
/*
118+
* \brief Shortcut of as Object to cast to a const pointer when T is an Object.
119+
*
120+
* \tparam T The object type.
121+
* \return The requested pointer, returns nullptr if type mismatches.
122+
*/
123+
template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
124+
TVM_FFI_INLINE const T* as() const {
125+
return this->as<const T*>().value_or(nullptr);
105126
}
106127

128+
/**
129+
* \brief Cast to a type T.
130+
*
131+
* \tparam T The type to cast to.
132+
* \return The casted value, or throws an exception if the cast is not possible.
133+
*/
107134
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
108135
TVM_FFI_INLINE T cast() const {
109-
std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_);
136+
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
110137
if (!opt.has_value()) {
111138
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
112139
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
@@ -115,16 +142,17 @@ class AnyView {
115142
return *std::move(opt);
116143
}
117144

118-
/*
119-
* \brief Shortcut of as Object to cast to a const pointer when T is an Object.
145+
/*!
146+
* \brief Try to cast to a type T, return std::nullopt if the cast is not possible.
120147
*
121-
* \tparam T The object type.
122-
* \return The requested pointer, returns nullptr if type mismatches.
148+
* \tparam T The type to cast to.
149+
* \return The casted value, or std::nullopt if the cast is not possible.
123150
*/
124-
template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
125-
TVM_FFI_INLINE const T* as() const {
126-
return this->as<const T*>().value_or(nullptr);
151+
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
152+
TVM_FFI_INLINE std::optional<T> try_cast() const {
153+
return TypeTraits<T>::TryCastFromAnyView(&data_);
127154
}
155+
128156
// comparison with nullptr
129157
TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept {
130158
return data_.type_index == TypeIndex::kTVMFFINone;
@@ -269,13 +297,45 @@ class Any {
269297
return *this;
270298
}
271299

300+
/**
301+
* \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible.
302+
*
303+
* \tparam T The type to cast to.
304+
* \return The casted value, or std::nullopt if the cast is not possible.
305+
* \note This function won't try to run type conversion (use try_cast for that purpose).
306+
*/
307+
template <typename T,
308+
typename = std::enable_if_t<TypeTraits<T>::storage_enabled || std::is_same_v<T, Any>>>
309+
TVM_FFI_INLINE std::optional<T> as() && {
310+
if constexpr (std::is_same_v<T, Any>) {
311+
return std::move(*this);
312+
} else {
313+
if (TypeTraits<T>::CheckAnyStrict(&data_)) {
314+
return TypeTraits<T>::MoveFromAnyAfterCheck(&data_);
315+
} else {
316+
return std::optional<T>(std::nullopt);
317+
}
318+
}
319+
}
320+
321+
/**
322+
* \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible.
323+
*
324+
* \tparam T The type to cast to.
325+
* \return The casted value, or std::nullopt if the cast is not possible.
326+
* \note This function won't try to run type conversion (use try_cast for that purpose).
327+
*/
272328
template <typename T,
273329
typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
274-
TVM_FFI_INLINE std::optional<T> as() const {
330+
TVM_FFI_INLINE std::optional<T> as() const& {
275331
if constexpr (std::is_same_v<T, Any>) {
276332
return *this;
277333
} else {
278-
return TypeTraits<T>::TryConvertFromAnyView(&data_);
334+
if (TypeTraits<T>::CheckAnyStrict(&data_)) {
335+
return TypeTraits<T>::CopyFromAnyViewAfterCheck(&data_);
336+
} else {
337+
return std::optional<T>(std::nullopt);
338+
}
279339
}
280340
}
281341

@@ -286,13 +346,18 @@ class Any {
286346
* \return The requested pointer, returns nullptr if type mismatches.
287347
*/
288348
template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
289-
TVM_FFI_INLINE const T* as() const {
349+
TVM_FFI_INLINE const T* as() const& {
290350
return this->as<const T*>().value_or(nullptr);
291351
}
292352

353+
/**
354+
* \brief Cast to a type T, throw an exception if the cast is not possible.
355+
*
356+
* \tparam T The type to cast to.
357+
*/
293358
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
294359
TVM_FFI_INLINE T cast() const& {
295-
std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_);
360+
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
296361
if (!opt.has_value()) {
297362
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
298363
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
@@ -301,13 +366,18 @@ class Any {
301366
return *std::move(opt);
302367
}
303368

369+
/**
370+
* \brief Cast to a type T, throw an exception if the cast is not possible.
371+
*
372+
* \tparam T The type to cast to.
373+
*/
304374
template <typename T, typename = std::enable_if_t<TypeTraits<T>::storage_enabled>>
305375
TVM_FFI_INLINE T cast() && {
306-
if (TypeTraits<T>::CheckAnyStorage(&data_)) {
307-
return TypeTraits<T>::MoveFromAnyStorageAfterCheck(&data_);
376+
if (TypeTraits<T>::CheckAnyStrict(&data_)) {
377+
return TypeTraits<T>::MoveFromAnyAfterCheck(&data_);
308378
}
309379
// slow path, try to do fallback convert
310-
std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_);
380+
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
311381
if (!opt.has_value()) {
312382
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
313383
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
@@ -316,6 +386,22 @@ class Any {
316386
return *std::move(opt);
317387
}
318388

389+
/**
390+
* \brief Try to cast to a type T.
391+
*
392+
* \tparam T The type to cast to.
393+
* \return The casted value, or std::nullopt if the cast is not possible.
394+
* \note use STL name since it to be more consistent with cast API.
395+
*/
396+
template <typename T,
397+
typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
398+
TVM_FFI_INLINE std::optional<T> try_cast() const {
399+
if constexpr (std::is_same_v<T, Any>) {
400+
return *this;
401+
} else {
402+
return TypeTraits<T>::TryCastFromAnyView(&data_);
403+
}
404+
}
319405
/*
320406
* \brief Check if the two Any are same type and value in shallow comparison.
321407
* \param other The other Any
@@ -412,23 +498,23 @@ struct AnyUnsafe : public ObjectUnsafe {
412498
}
413499

414500
template <typename T>
415-
static TVM_FFI_INLINE bool CheckAnyStorage(const Any& ref) {
416-
return TypeTraits<T>::CheckAnyStorage(&(ref.data_));
501+
static TVM_FFI_INLINE bool CheckAnyStrict(const Any& ref) {
502+
return TypeTraits<T>::CheckAnyStrict(&(ref.data_));
417503
}
418504

419505
template <typename T>
420-
static TVM_FFI_INLINE T CopyFromAnyStorageAfterCheck(const Any& ref) {
506+
static TVM_FFI_INLINE T CopyFromAnyViewAfterCheck(const Any& ref) {
421507
if constexpr (!std::is_same_v<T, Any>) {
422-
return TypeTraits<T>::CopyFromAnyStorageAfterCheck(&(ref.data_));
508+
return TypeTraits<T>::CopyFromAnyViewAfterCheck(&(ref.data_));
423509
} else {
424510
return ref;
425511
}
426512
}
427513

428514
template <typename T>
429-
static TVM_FFI_INLINE T MoveFromAnyStorageAfterCheck(Any&& ref) {
515+
static TVM_FFI_INLINE T MoveFromAnyAfterCheck(Any&& ref) {
430516
if constexpr (!std::is_same_v<T, Any>) {
431-
return TypeTraits<T>::MoveFromAnyStorageAfterCheck(&(ref.data_));
517+
return TypeTraits<T>::MoveFromAnyAfterCheck(&(ref.data_));
432518
} else {
433519
return std::move(ref);
434520
}
@@ -461,7 +547,7 @@ struct AnyHash {
461547
if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
462548
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
463549
const BytesObjBase* src_str =
464-
details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const BytesObjBase*>(src);
550+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
465551
return details::StableHashBytes(src_str->data, src_str->size);
466552
} else {
467553
return src.data_.v_uint64;
@@ -487,9 +573,9 @@ struct AnyEqual {
487573
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
488574
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
489575
const BytesObjBase* lhs_str =
490-
details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const BytesObjBase*>(lhs);
576+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs);
491577
const BytesObjBase* rhs_str =
492-
details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const BytesObjBase*>(rhs);
578+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs);
493579
return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0;
494580
}
495581
return false;

include/tvm/ffi/container/array.h

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,7 @@ class Array : public ObjectRef {
386386
// iterators
387387
struct ValueConverter {
388388
using ResultType = T;
389-
static T convert(const Any& n) {
390-
return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(n);
391-
}
389+
static T convert(const Any& n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(n); }
392390
};
393391

394392
using iterator = details::IterAdapter<ValueConverter, const Any*>;
@@ -427,7 +425,7 @@ class Array : public ObjectRef {
427425
if (i < 0 || i >= p->size_) {
428426
TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_;
429427
}
430-
return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*(p->begin() + i));
428+
return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->begin() + i));
431429
}
432430

433431
/*! \return The size of the array */
@@ -451,7 +449,7 @@ class Array : public ObjectRef {
451449
if (p == nullptr || p->size_ == 0) {
452450
TVM_FFI_THROW(IndexError) << "cannot index a empty array";
453451
}
454-
return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*(p->begin()));
452+
return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->begin()));
455453
}
456454

457455
/*! \return The last element of the array */
@@ -460,7 +458,7 @@ class Array : public ObjectRef {
460458
if (p == nullptr || p->size_ == 0) {
461459
TVM_FFI_THROW(IndexError) << "cannot index a empty array";
462460
}
463-
return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*(p->end() - 1));
461+
return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->end() - 1));
464462
}
465463

466464
public:
@@ -835,7 +833,7 @@ class Array : public ObjectRef {
835833
// no other shared copies of the array.
836834
auto arr = static_cast<ArrayObj*>(data.get());
837835
for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
838-
T value = details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*it);
836+
T value = details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it);
839837
// reset the original value to nullptr, to ensure unique ownership
840838
it->reset();
841839
T mapped = fmap(std::move(value));
@@ -860,7 +858,7 @@ class Array : public ObjectRef {
860858
// `T`.
861859
bool all_identical = true;
862860
for (; it != arr->end(); it++) {
863-
U mapped = fmap(details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*it));
861+
U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it));
864862
if (!(*it).same_as(mapped)) {
865863
// At least one mapped element is different than the
866864
// original. Therefore, prepare the output array,
@@ -914,7 +912,7 @@ class Array : public ObjectRef {
914912
// so we can either start or resume the iteration from that point,
915913
// with no further checks on the result.
916914
for (; it != arr->end(); it++) {
917-
U mapped = fmap(details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*it));
915+
U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it));
918916
output->SetItem(it - arr->begin(), std::move(mapped));
919917
}
920918

@@ -952,7 +950,7 @@ inline constexpr bool use_default_type_traits_v<Array<T>> = false;
952950
template <typename T>
953951
struct TypeTraits<Array<T>> : public ObjectRefTypeTraitsBase<Array<T>> {
954952
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;
955-
using ObjectRefTypeTraitsBase<Array<T>>::CopyFromAnyStorageAfterCheck;
953+
using ObjectRefTypeTraitsBase<Array<T>>::CopyFromAnyViewAfterCheck;
956954

957955
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
958956
if (src->type_index != TypeIndex::kTVMFFIArray) {
@@ -962,10 +960,10 @@ struct TypeTraits<Array<T>> : public ObjectRefTypeTraitsBase<Array<T>> {
962960
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
963961
for (size_t i = 0; i < n->size(); i++) {
964962
const Any& any_v = (*n)[i];
965-
// CheckAnyStorage is cheaper than as<T>
966-
if (details::AnyUnsafe::CheckAnyStorage<T>(any_v)) continue;
963+
// CheckAnyStrict is cheaper than try_cast<T>
964+
if (details::AnyUnsafe::CheckAnyStrict<T>(any_v)) continue;
967965
// try see if p is convertible to T
968-
if (any_v.as<T>()) continue;
966+
if (any_v.try_cast<T>()) continue;
969967
// now report the accurate mismatch information
970968
return "Array[index " + std::to_string(i) + ": " +
971969
details::AnyUnsafe::GetMismatchTypeInfo<T>(any_v) + "]";
@@ -975,50 +973,50 @@ struct TypeTraits<Array<T>> : public ObjectRefTypeTraitsBase<Array<T>> {
975973
TVM_FFI_UNREACHABLE();
976974
}
977975

978-
static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
976+
static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
979977
if (src->type_index != TypeIndex::kTVMFFIArray) return false;
980978
if constexpr (std::is_same_v<T, Any>) {
981979
return true;
982980
} else {
983981
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
984982
for (size_t i = 0; i < n->size(); i++) {
985983
const Any& any_v = (*n)[i];
986-
if (!details::AnyUnsafe::CheckAnyStorage<T>(any_v)) return false;
984+
if (!details::AnyUnsafe::CheckAnyStrict<T>(any_v)) return false;
987985
}
988986
return true;
989987
}
990988
}
991989

992-
static TVM_FFI_INLINE std::optional<Array<T>> TryConvertFromAnyView(const TVMFFIAny* src) {
990+
static TVM_FFI_INLINE std::optional<Array<T>> TryCastFromAnyView(const TVMFFIAny* src) {
993991
// try to run conversion.
994992
if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt;
995993
if constexpr (!std::is_same_v<T, Any>) {
996994
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
997995
bool storage_check = [&]() {
998996
for (size_t i = 0; i < n->size(); i++) {
999997
const Any& any_v = (*n)[i];
1000-
if (!details::AnyUnsafe::CheckAnyStorage<T>(any_v)) return false;
998+
if (!details::AnyUnsafe::CheckAnyStrict<T>(any_v)) return false;
1001999
}
10021000
return true;
10031001
}();
10041002
// fast path, if storage check passes, we can return the array directly.
10051003
if (storage_check) {
1006-
return CopyFromAnyStorageAfterCheck(src);
1004+
return CopyFromAnyViewAfterCheck(src);
10071005
}
10081006
// slow path, try to run a conversion to Array<T>
10091007
Array<T> result;
10101008
result.reserve(n->size());
10111009
for (size_t i = 0; i < n->size(); i++) {
10121010
const Any& any_v = (*n)[i];
1013-
if (auto opt_v = any_v.as<T>()) {
1011+
if (auto opt_v = any_v.try_cast<T>()) {
10141012
result.push_back(*std::move(opt_v));
10151013
} else {
10161014
return std::nullopt;
10171015
}
10181016
}
10191017
return result;
10201018
} else {
1021-
return CopyFromAnyStorageAfterCheck(src);
1019+
return CopyFromAnyViewAfterCheck(src);
10221020
}
10231021
}
10241022

0 commit comments

Comments
 (0)