From df420b04019dd1d700851cf241c2812b8380542a Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 1 Jul 2025 14:13:37 -0400 Subject: [PATCH 1/3] [REFACTOR][FFI] Phase out old VisitAttrs mechanism This PR phases out the old VisitAttrs mechanism as we fully transition to use the new reflection mechanism for attribute get and object creation. --- ffi/include/tvm/ffi/object.h | 1 - include/tvm/arith/int_solver.h | 3 - include/tvm/ir/attrs.h | 602 +----------------- include/tvm/ir/expr.h | 10 +- include/tvm/meta_schedule/database.h | 6 +- include/tvm/meta_schedule/extracted_task.h | 2 +- include/tvm/meta_schedule/feature_extractor.h | 4 - include/tvm/meta_schedule/measure_callback.h | 4 - include/tvm/meta_schedule/measure_candidate.h | 2 - include/tvm/meta_schedule/mutator.h | 4 - include/tvm/meta_schedule/postproc.h | 4 - include/tvm/meta_schedule/profiler.h | 2 - include/tvm/meta_schedule/runner.h | 8 +- include/tvm/meta_schedule/schedule_rule.h | 4 - include/tvm/meta_schedule/search_strategy.h | 2 - include/tvm/meta_schedule/space_generator.h | 4 - include/tvm/meta_schedule/tune_context.h | 2 - include/tvm/node/reflection.h | 92 +-- include/tvm/node/structural_hash.h | 1 - include/tvm/script/printer/doc.h | 50 +- include/tvm/script/printer/ir_docsifier.h | 4 +- include/tvm/target/tag.h | 2 +- include/tvm/target/target.h | 2 +- include/tvm/target/target_info.h | 2 +- include/tvm/target/target_kind.h | 2 +- include/tvm/te/operation.h | 11 +- include/tvm/te/tensor.h | 2 +- include/tvm/tir/expr.h | 18 - include/tvm/tir/var.h | 2 - python/tvm/ir/attrs.py | 20 - python/tvm/topi/gpu/scan.py | 1 + src/contrib/msc/core/ir/graph_builder.h | 66 +- src/ir/attrs.cc | 13 +- src/ir/transform.cc | 4 +- .../search_strategy/replay_func.cc | 2 - .../search_strategy/replay_trace.cc | 2 - .../space_generator/post_order_apply.cc | 2 - .../space_generator/schedule_fn.cc | 2 - .../space_generator/space_generator_union.cc | 2 - src/node/reflection.cc | 233 +------ src/node/serialization.cc | 243 +++---- src/node/structural_hash.cc | 21 - .../contrib/codegen_json/codegen_json.h | 108 ++-- src/relax/backend/vm/vm_shape_lower.cc | 2 +- src/relax/ir/binding_rewrite.cc | 2 +- .../transform/static_plan_block_memory.cc | 2 +- src/script/printer/ir_docsifier.cc | 44 +- src/script/printer/relax/call.cc | 84 +-- src/te/operation/create_primfunc.cc | 2 +- src/tir/ir/data_type_rewriter.cc | 2 +- src/tir/schedule/concrete_schedule.h | 4 +- .../schedule/primitive/decompose_padding.cc | 2 +- 52 files changed, 270 insertions(+), 1445 deletions(-) diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 164c4e2f369c..ce667d6b4b65 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -214,7 +214,6 @@ class Object { static constexpr int32_t _type_depth = 0; // extra fields used by plug-ins for attribute visiting // and structural information - static constexpr const bool _type_has_method_visit_attrs = true; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; // The following functions are provided by macro diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 4716dc7aa274..dd9259cf97cb 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -83,7 +83,6 @@ class IntGroupBoundsNode : public Object { hash_reduce(upper); } - static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntGroupBounds"; TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object); @@ -174,7 +173,6 @@ class IntConstraintsNode : public Object { hash_reduce(relations); } - static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntConstraints"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object); @@ -240,7 +238,6 @@ class IntConstraintsTransformNode : public Object { hash_reduce(dst_to_src); } - static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntConstraintsTransform"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object); diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index caa1b2fb8c78..8715643d709a 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -44,22 +44,6 @@ #include namespace tvm { -/*! - * \brief Declare an attribute function. - * \param ClassName The name of the class. - * \param TypeKey The type key to be used by the TVM node system. - */ -#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ - static constexpr const char* _type_key = TypeKey; \ - TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode); \ - template \ - void _tvm_VisitAttrs(FVisit& _tvm_fvisit) // NOLINT(*) - -/*! - * \brief Declare an attribute field. - * \param FieldName The field name. - */ -#define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName) /*! * \brief Create a NodeRef type that represents null. @@ -77,15 +61,6 @@ inline DataType NullValue() { return DataType(DataType::kHandle, 0, 0); } -/*! \brief Error thrown during attribute checking. */ -struct AttrError : public Error { - /*! - * \brief constructor - * \param msg error message - */ - explicit AttrError(std::string msg) : Error("AttributeError", msg, TVM_FFI_TRACEBACK_HERE) {} -}; - /*! * \brief Information about attribute fields in string representations. */ @@ -98,13 +73,16 @@ class AttrFieldInfoNode : public Object { /*! \brief detailed description of the type */ String description; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("type_info", &type_info); - v->Visit("description", &description); + static void RegisterReflection() { + namespace rfl = ffi::reflection; + rfl::ObjectDef() + .def_ro("name", &AttrFieldInfoNode::name) + .def_ro("type_info", &AttrFieldInfoNode::type_info) + .def_ro("description", &AttrFieldInfoNode::description); } static constexpr const char* _type_key = "ir.AttrFieldInfo"; + static constexpr bool _type_has_method_visit_attrs = false; static constexpr bool _type_has_method_sequal_reduce = false; static constexpr bool _type_has_method_shash_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); @@ -126,8 +104,6 @@ class BaseAttrsNode : public Object { public: /*! \brief virtual destructor */ virtual ~BaseAttrsNode() {} - // visit function - virtual void VisitAttrs(AttrVisitor* v) {} /*! * \brief Initialize the attributes by sequence of arguments * \param args The positional arguments in the form @@ -135,23 +111,6 @@ class BaseAttrsNode : public Object { */ template inline void InitBySeq(Args&&... args); - /*! - * \brief Print readible docstring to ostream, add newline. - * \param os the stream to print the docstring to. - */ - inline void PrintDocString(std::ostream& os) const; // NOLINT(*) - /*! - * \brief Visit attributes that do not equal the default value. - * - * \note This is useful to extract fields for concise printing. - * \param v The visitor - */ - TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0; - /*! - * \brief Get the field information - * \return The fields in the Attrs. - */ - TVM_DLL virtual Array ListFieldInfo() const = 0; /*! * \brief Initialize the attributes by arguments. * \param kwargs The key value pairs for initialization. @@ -188,17 +147,17 @@ class DictAttrsNode : public BaseAttrsNode { /*! \brief internal attrs map */ Map dict; + static void RegisterReflection() { + namespace rfl = ffi::reflection; + rfl::ObjectDef().def_ro("__dict__", &DictAttrsNode::dict); + } + bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const { return equal(dict, other->dict); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); } - - // implementations - void VisitAttrs(AttrVisitor* v) final; - void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final; - Array ListFieldInfo() const final; // type info static constexpr const char* _type_key = "ir.DictAttrs"; @@ -357,7 +316,6 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, Any attr_value) static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value); - return input; } @@ -419,529 +377,6 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { return input; } -// Namespace containing detail implementations -namespace detail { - -using tvm::ffi::AnyView; - -// helper entry that does nothing in set_default/bound/describe calls. -struct AttrNopEntry { - using TSelf = AttrNopEntry; - - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } - template - TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { - return *this; - } - template - TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { - return *this; - } - template - TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { - return *this; - } -}; - -// Wrapper for normal visitor. -class AttrNormalVisitor { - public: - explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {} - template - AttrNopEntry operator()(const char* key, T* value) { - visitor_->Visit(key, value); - return AttrNopEntry(); - } - - private: - AttrVisitor* visitor_; -}; - -class AttrsSEqualVisitor { - public: - bool result_{true}; - // constructor - AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal) - : lhs_(lhs), rhs_(rhs), equal_(equal) {} - template - AttrNopEntry operator()(const char* key, T* lhs_value) { - if (!result_) return AttrNopEntry(); - const T* rhs_value = reinterpret_cast( - reinterpret_cast(rhs_) + - (reinterpret_cast(lhs_value) - reinterpret_cast(lhs_))); - if (!equal_(*lhs_value, *rhs_value)) { - result_ = false; - } - return AttrNopEntry(); - } - - private: - const Object* lhs_; - const Object* rhs_; - const SEqualReducer& equal_; -}; - -class AttrsSHashVisitor { - public: - explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {} - - template - AttrNopEntry operator()(const char* key, T* value) { - hash_reducer_(*value); - return AttrNopEntry(); - } - - private: - const SHashReducer& hash_reducer_; -}; - -// helper entry that does initialization, set default. -template -struct AttrInitEntry { - // The attributes - using TSelf = AttrInitEntry; - // The type key - const char* type_key_; - // field name - const char* key_; - // internal value. - T* value_; - // whether the value is missing. - // NOTE: initialize to false so that the destructor does not throw unless - // AttrInitVisitor::operator() is committed to returning an instance of this class. - // It is expected not to set this to true until that is true. - bool value_missing_{false}; - - AttrInitEntry() = default; - - AttrInitEntry(AttrInitEntry&& other) { - type_key_ = other.type_key_; - key_ = other.key_; - value_ = other.value_; - value_missing_ = other.value_missing_; - // avoid unexpected throw - other.value_missing_ = false; - } - - // If the value is still missing in destruction time throw an error. - ~AttrInitEntry() DMLC_THROW_EXCEPTION { - if (value_missing_) { - std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. " - << "If the key is defined check that its type matches the declared type."; - throw AttrError(os.str()); - } - } - // override fields. - // This function sets the lower bound of the attribute - TSelf& set_lower_bound(const T& begin) { - if (this->value_missing_) return *this; - const T& val = *value_; - if (begin > val) { - std::ostringstream os; - os << type_key_ << "." << key_ << "'s " - << "value " << val << " is smaller than the lower bound " << begin; - throw AttrError(os.str()); - } - return *this; - } - // This function sets the upper bound of the attribute - TSelf& set_upper_bound(const T& end) { - if (this->value_missing_) return *this; - const T& val = *value_; - if (val > end) { - std::ostringstream os; - os << type_key_ << "." << key_ << "'s " - << "value " << val << " is bigger than the upper bound " << end; - throw AttrError(os.str()); - } - return *this; - } - // set default when - TSelf& set_default(const T& value) { - if (!value_missing_) return *this; - *value_ = value; - value_missing_ = false; - return *this; - } - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } -}; - -// Template function to allow smart conversion -// from Expr types into the constants. -template -inline void SetValue(T* ptr, const ffi::AnyView& val) { - *ptr = val.cast(); -} - -template -inline void SetIntValue(T* ptr, const ffi::AnyView& val) { - if (auto opt_int = val.try_cast()) { - *ptr = static_cast(opt_int.value()); - } else { - IntImm expr = val.cast(); - *ptr = static_cast(expr->value); - } -} - -// Workaround for GCC8.1 / GCC8.2 -template <> -inline void SetValue(DataType* ptr, const ffi::AnyView& val) { - *ptr = DataType(val.cast()); -} - -template <> -inline void SetValue(std::string* ptr, const ffi::AnyView& val) { - *ptr = val.cast(); -} - -template <> -inline void SetValue(double* ptr, const ffi::AnyView& val) { - if (auto opt_double = val.try_cast()) { - *ptr = opt_double.value(); - } else { - ObjectRef expr = val.cast(); - ICHECK(expr.defined()); - if (const IntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); - } else if (const FloatImmNode* op = expr.as()) { - *ptr = static_cast(op->value); - } else { - LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); - } - } -} -template <> -inline void SetValue(int* ptr, const ffi::AnyView& val) { - SetIntValue(ptr, val); -} -template <> -inline void SetValue(int64_t* ptr, const ffi::AnyView& val) { - SetIntValue(ptr, val); -} -template <> -inline void SetValue(uint64_t* ptr, const ffi::AnyView& val) { - SetIntValue(ptr, val); -} -template <> -inline void SetValue(bool* ptr, const ffi::AnyView& val) { - SetIntValue(ptr, val); -} - -// Visitor for value initialization -template -class AttrInitVisitor { - public: - // Counter of number of matched attributes during visit. - // This is used to decide if there is additional unmatched attributes. - size_t hit_count_{0}; - // constructor - AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {} - - template - AttrInitEntry operator()(const char* key, T* value) { - ffi::AnyView val; - AttrInitEntry opt; - opt.type_key_ = type_key_; - opt.key_ = key; - opt.value_ = value; - if (ffind_(key, &val)) { - SetValue(value, val); - opt.value_missing_ = false; - ++hit_count_; - } else { - opt.value_missing_ = true; - } -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Wpragmas" -#pragma GCC diagnostic ignored "-Wpessimizing-move" -#endif - return std::move(opt); - } - - private: - // the type key - const char* type_key_; - FFind ffind_; -}; - -template -inline AttrInitVisitor CreateInitVisitor(const char* type_key, FFind ffind) { - return AttrInitVisitor(type_key, ffind); -} - -/*! - * \brief Helper struct to get the type name known to tvm. - * \tparam T the type we are interested in. - */ -template -struct TypeName { - static constexpr const char* value = T::ContainerType::_type_key; -}; - -template <> -struct TypeName { - static constexpr const char* value = "int"; -}; - -template <> -struct TypeName { - static constexpr const char* value = "int"; -}; - -template <> -struct TypeName> { - static constexpr const char* value = "Optional[int]"; -}; - -template <> -struct TypeName> { - static constexpr const char* value = "Optional[float]"; -}; - -template <> -struct TypeName { - static constexpr const char* value = "int"; -}; - -template <> -struct TypeName { - static constexpr const char* value = "DataType"; -}; - -template <> -struct TypeName { - static constexpr const char* value = "str"; -}; - -template <> -struct TypeName { - static constexpr const char* value = "bool"; -}; - -template <> -struct TypeName { - static constexpr const char* value = "handle"; -}; - -template <> -struct TypeName { - static constexpr const char* value = "float"; -}; - -class AttrDocEntry { - public: - using TSelf = AttrDocEntry; - - explicit AttrDocEntry(ObjectPtr info) : info_(info) {} - TSelf& describe(const char* str) { - info_->description = str; - return *this; - } - template - TSelf& set_default(const T& value) { - std::ostringstream os; - os << info_->type_info << ", default=" << value; - info_->type_info = os.str(); - return *this; - } - template - TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { - return *this; - } - template - TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { - return *this; - } - - private: - ObjectPtr info_; -}; - -class AttrDocVisitor { - public: - template - AttrDocEntry operator()(const char* key, T* v) { - ObjectPtr info = make_object(); - info->name = key; - info->type_info = TypeName::value; - fields_.push_back(AttrFieldInfo(info)); - return AttrDocEntry(info); - } - - Array fields_; -}; - -class AttrExistVisitor { - public: - std::string key_; - bool exist_{false}; - - template - AttrNopEntry operator()(const char* key, T* v) { - if (exist_) return AttrNopEntry(); - if (key == key_) exist_ = true; - return AttrNopEntry(); - } -}; - -template -struct AttrTriggerNonDefaultEntry { - using TSelf = AttrTriggerNonDefaultEntry; - // constructor - AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data) - : visitor_(visitor), key_(key), data_(data) {} - - ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION { - if (trigger_) { - visitor_->Visit(key_, data_); - } - } - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } - TSelf& set_default(const T& value) { - if (tvm::StructuralEqual()(value, *data_)) { - trigger_ = false; - } - return *this; - } - TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } - TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } - - private: - AttrVisitor* visitor_; - const char* key_; - T* data_; - bool trigger_{true}; -}; - -class AttrNonDefaultVisitor { - public: - explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {} - template - AttrTriggerNonDefaultEntry operator()(const char* key, T* value) { - return AttrTriggerNonDefaultEntry(visitor_, key, value); - } - - private: - AttrVisitor* visitor_; -}; -} // namespace detail - -/*! - * \brief The base class of the all the - * Use "curiously recurring template pattern". - * - * \tparam DerivedType The final attribute type. - */ -template -class AttrsNode : public BaseAttrsNode { - public: - void VisitAttrs(AttrVisitor* v) { - ::tvm::detail::AttrNormalVisitor vis(v); - self()->_tvm_VisitAttrs(vis); - } - - void VisitNonDefaultAttrs(AttrVisitor* v) { - ::tvm::detail::AttrNonDefaultVisitor vis(v); - self()->_tvm_VisitAttrs(vis); - } - - void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final { - ICHECK_EQ(args.size() % 2, 0); - const int kLinearSearchBound = 16; - int hit_count = 0; - // applies two strategies to lookup - if (args.size() < kLinearSearchBound) { - // linear search. - auto ffind = [&args](const char* key, ffi::AnyView* val) { - for (int i = 0; i < args.size(); i += 2) { - if (!std::strcmp(key, args[i].cast())) { - *val = args[i + 1]; - return true; - } - } - return false; - }; - auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind); - self()->_tvm_VisitAttrs(vis); - hit_count = vis.hit_count_; - } else { - // construct a map then do lookup. - std::unordered_map kwargs; - for (int i = 0; i < args.size(); i += 2) { - kwargs[args[i].cast()] = args[i + 1]; - } - auto ffind = [&kwargs](const char* key, ffi::AnyView* val) { - auto it = kwargs.find(key); - if (it != kwargs.end()) { - *val = it->second; - return true; - } - return false; - }; - auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind); - self()->_tvm_VisitAttrs(vis); - hit_count = vis.hit_count_; - } - // error handling, slow path - if (hit_count * 2 != args.size() && !allow_unknown) { - for (int i = 0; i < args.size(); i += 2) { - ::tvm::detail::AttrExistVisitor visitor; - visitor.key_ = args[i].cast(); - self()->_tvm_VisitAttrs(visitor); - if (!visitor.exist_) { - std::ostringstream os; - os << DerivedType::_type_key << " does not have field \'" << visitor.key_ - << "\', Possible fields:\n"; - os << "----------------\n"; - this->PrintDocString(os); - throw AttrError(os.str()); - } - } - } - } - - bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const { - DerivedType* pself = self(); - ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal); - self()->_tvm_VisitAttrs(visitor); - return visitor.result_; - } - - void SHashReduce(SHashReducer hash_reducer) const { - ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer); - self()->_tvm_VisitAttrs(visitor); - } - - Array ListFieldInfo() const final { - ::tvm::detail::AttrDocVisitor visitor; - self()->_tvm_VisitAttrs(visitor); - return visitor.fields_; - } - - private: - DerivedType* self() const { - return const_cast(static_cast(this)); - } -}; - -template -inline void BaseAttrsNode::InitBySeq(Args&&... args) { - ffi::Function pf( - [this](const ffi::PackedArgs& args, ffi::Any* rv) { this->InitByPackedArgs(args); }); - pf(std::forward(args)...); -} - -inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*) - Array entry = this->ListFieldInfo(); - for (AttrFieldInfo info : entry) { - os << info->name << " : " << info->type_info << '\n'; - if (info->description.length() != 0) { - os << " " << info->description << '\n'; - } - } -} - /*! * \brief Adapter for AttrsNode with the new reflection API. * @@ -956,14 +391,6 @@ class AttrsNodeReflAdapter : public BaseAttrsNode { void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final { LOG(FATAL) << "`" << DerivedType::_type_key << "` uses new reflection mechanism for init"; } - void VisitNonDefaultAttrs(AttrVisitor* v) final { - LOG(FATAL) << "`" << DerivedType::_type_key - << "` uses new reflection mechanism for visit non default attrs"; - } - void VisitAttrs(AttrVisitor* v) final { - LOG(FATAL) << "`" << DerivedType::_type_key - << "` uses new reflection mechanism for visit attrs"; - } bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const { const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex()); @@ -991,11 +418,6 @@ class AttrsNodeReflAdapter : public BaseAttrsNode { }); } - Array ListFieldInfo() const final { - // use the new reflection to list field info - return Array(); - } - private: DerivedType* self() const { return const_cast(static_cast(this)); diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 07a198cfc33b..4bc54fd0df5e 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -62,7 +62,7 @@ class BaseExprNode : public Object { } static constexpr const char* _type_key = "ir.BaseExpr"; - static constexpr const bool _type_has_method_visit_attrs = true; + static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 64; @@ -113,8 +113,6 @@ class PrimExprNode : public BaseExprNode { refl::ObjectDef().def_ro("dtype", &PrimExprNode::dtype); } - static constexpr const bool _type_has_method_visit_attrs = false; - TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "ir.PrimExpr"; @@ -461,8 +459,6 @@ class GlobalVarNode : public RelaxExprNode { /*! \brief The name of the variable, this only acts as a hint. */ String name_hint; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("name_hint", &GlobalVarNode::name_hint); @@ -549,8 +545,6 @@ class FloatImmNode : public PrimExprNode { /*! \brief The constant value content. */ double value; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &FloatImmNode::value); @@ -700,8 +694,6 @@ class RangeNode : public Object { RangeNode(PrimExpr min, PrimExpr extent, Span span = Span()) : min(min), extent(extent), span(span) {} - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index db8ef0a348ce..90981900b214 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -55,7 +55,7 @@ class WorkloadNode : public runtime::Object { } static constexpr const char* _type_key = "meta_schedule.Workload"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object); /*! @@ -137,7 +137,7 @@ class TuningRecordNode : public runtime::Object { } static constexpr const char* _type_key = "meta_schedule.TuningRecord"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); /*! \brief Construct the measure candidate given the initial IR module and trace @@ -397,8 +397,6 @@ class PyDatabaseNode : public DatabaseNode { // `f_size` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - bool HasWorkload(const IRModule& mod) final { ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!"; return f_has_workload(mod); diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index e3e1d8272327..5023d5bbfcdc 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -64,7 +64,7 @@ class ExtractedTaskNode : public runtime::Object { } static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object); }; diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index d04189700516..d925503d8d9e 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -44,8 +44,6 @@ class FeatureExtractorNode : public runtime::Object { // No fields to register } - static constexpr const bool _type_has_method_visit_attrs = false; - /*! * \brief Extract features from the given measure candidate. * \param context The tuning context for feature extraction. @@ -86,8 +84,6 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { // `f_as_string` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - Array ExtractFrom(const TuneContext& context, const Array& candidates) final; diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index 862606640947..ba5c063e0989 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -47,8 +47,6 @@ class MeasureCallbackNode : public runtime::Object { // No fields to register } - static constexpr const bool _type_has_method_visit_attrs = false; - /*! * \brief Apply a measure callback rule with given arguments. * \param task_scheduler The task scheduler. @@ -100,8 +98,6 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { // `f_as_string` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - void Apply(const TaskScheduler& task_scheduler, // int task_id, // const Array& measure_candidates, // diff --git a/include/tvm/meta_schedule/measure_candidate.h b/include/tvm/meta_schedule/measure_candidate.h index 79feda757688..edbc4149f72a 100644 --- a/include/tvm/meta_schedule/measure_candidate.h +++ b/include/tvm/meta_schedule/measure_candidate.h @@ -45,8 +45,6 @@ class MeasureCandidateNode : public runtime::Object { .def_ro("args_info", &MeasureCandidateNode::args_info); } - static constexpr const bool _type_has_method_visit_attrs = false; - static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); }; diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 7e00f9d72e3a..0aa32409ecaa 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -45,8 +45,6 @@ class MutatorNode : public runtime::Object { // No fields to register } - static constexpr const bool _type_has_method_visit_attrs = false; - /*! * \brief Initialize the design space generator with tuning context. * \param context The tuning context for initialization. @@ -169,8 +167,6 @@ class PyMutatorNode : public MutatorNode { // `f_as_string` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - void InitializeWithTuneContext(const TuneContext& context) final; Optional Apply(const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) final; diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index cfbf9c702e65..1a0c53572c6a 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -44,8 +44,6 @@ class PostprocNode : public runtime::Object { // No fields to register } - static constexpr const bool _type_has_method_visit_attrs = false; - /*! * \brief Initialize the design space generator with tuning context. * \param context The tuning context for initialization. @@ -202,8 +200,6 @@ class PyPostprocNode : public PostprocNode { // `f_as_string` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - void InitializeWithTuneContext(const TuneContext& context) final; bool Apply(const tir::Schedule& sch) final; Postproc Clone() const final; diff --git a/include/tvm/meta_schedule/profiler.h b/include/tvm/meta_schedule/profiler.h index 21b77109ecc1..e58906e484c3 100644 --- a/include/tvm/meta_schedule/profiler.h +++ b/include/tvm/meta_schedule/profiler.h @@ -65,8 +65,6 @@ class ProfilerNode : public runtime::Object { // `total_timer` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - static constexpr const char* _type_key = "meta_schedule.Profiler"; TVM_DECLARE_FINAL_OBJECT_INFO(ProfilerNode, runtime::Object); diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 80d5816db031..e0490234d7b2 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -51,7 +51,7 @@ class RunnerInputNode : public runtime::Object { } static constexpr const char* _type_key = "meta_schedule.RunnerInput"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object); }; @@ -87,7 +87,7 @@ class RunnerResultNode : public runtime::Object { } static constexpr const char* _type_key = "meta_schedule.RunnerResult"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object); }; @@ -134,8 +134,6 @@ class RunnerFutureNode : public runtime::Object { // `f_result` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - /*! * \brief Check whether the runner has finished. * \return A boolean indicating whether the runner has finished. @@ -228,8 +226,6 @@ class PyRunnerNode : public RunnerNode { // `f_run` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - Array Run(Array runner_inputs) final { ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; return f_run(runner_inputs); diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index e702e1d2cbfe..d86b74e81996 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -47,8 +47,6 @@ class ScheduleRuleNode : public runtime::Object { // No fields to register } - static constexpr const bool _type_has_method_visit_attrs = false; - /*! * \brief Initialize the design space generator with tuning context. * \param context The tuning context for initialization. @@ -332,8 +330,6 @@ class PyScheduleRuleNode : public ScheduleRuleNode { // `f_clone` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - void InitializeWithTuneContext(const TuneContext& context) final; Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; ScheduleRule Clone() const final; diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 5133e9fd8a48..65fae52bafbe 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -251,8 +251,6 @@ class PySearchStrategyNode : public SearchStrategyNode { // `f_clone` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - void InitializeWithTuneContext(const TuneContext& context) final; void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, const Optional& database, const Optional& cost_model) final; diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 68f26a6bfee9..f148314b52e2 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -91,8 +91,6 @@ class SpaceGeneratorNode : public runtime::Object { .def_ro("mutator_probs", &SpaceGeneratorNode::mutator_probs); } - static constexpr const bool _type_has_method_visit_attrs = false; - /*! \brief Default destructor */ virtual ~SpaceGeneratorNode() = default; @@ -223,8 +221,6 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { // `f_clone` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - void InitializeWithTuneContext(const TuneContext& context) final; Array GenerateDesignSpace(const IRModule& mod) final; SpaceGenerator Clone() const final; diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 65d48e000dcc..793b4e672981 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -78,8 +78,6 @@ class TuneContextNode : public runtime::Object { // `logger` is not registered } - static constexpr const bool _type_has_method_visit_attrs = false; - /*! * \brief Initialize members that needs initialization with tune context. */ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index e56639570e37..9b3186bac117 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -43,39 +43,6 @@ using runtime::Object; using runtime::ObjectPtr; using runtime::ObjectRef; -/*! - * \brief Visitor class to get the attributes of an AST/IR node. - * The content is going to be called for each field. - * - * Each objects that wants reflection will need to implement - * a VisitAttrs function and call visitor->Visit on each of its field. - */ -class AttrVisitor { - public: - //! \cond Doxygen_Suppress - TVM_DLL virtual ~AttrVisitor() = default; - TVM_DLL virtual void Visit(const char* key, double* value) = 0; - TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0; - TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0; - TVM_DLL virtual void Visit(const char* key, int* value) = 0; - TVM_DLL virtual void Visit(const char* key, bool* value) = 0; - TVM_DLL virtual void Visit(const char* key, std::string* value) = 0; - TVM_DLL virtual void Visit(const char* key, void** value) = 0; - TVM_DLL virtual void Visit(const char* key, DataType* value) = 0; - TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0; - TVM_DLL virtual void Visit(const char* key, ffi::ObjectRef* value) = 0; - TVM_DLL virtual void Visit(const char* key, Optional* value) = 0; - TVM_DLL virtual void Visit(const char* key, Optional* value) = 0; - - template ::value>::type> - void Visit(const char* key, ENum* ptr) { - static_assert(std::is_same::type>::value, - "declare enum to be enum int to use visitor"); - this->Visit(key, reinterpret_cast(ptr)); - } - //! \endcond -}; - /*! * \brief Virtual function table to support IR/AST node reflection. * @@ -84,13 +51,6 @@ class AttrVisitor { */ class ReflectionVTable { public: - /*! - * \brief Visitor function. - * \note We use function pointer, instead of std::function - * to reduce the dispatch overhead as field visit - * does not need as much customization. - */ - typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor); /*! * \brief Equality comparison function. */ @@ -112,12 +72,6 @@ class ReflectionVTable { * \return bytes The bytes that can be used to recover the object. */ typedef std::string (*FReprBytes)(const Object* self); - /*! - * \brief Dispatch the VisitAttrs function. - * \param self The pointer to the object. - * \param visitor The attribute visitor. - */ - inline void VisitAttrs(Object* self, AttrVisitor* visitor) const; /*! * \brief Get repr bytes if any. * \param self The pointer to the object. @@ -188,8 +142,6 @@ class ReflectionVTable { inline Registry Register(); private: - /*! \brief Attribute visitor. */ - std::vector fvisit_attrs_; /*! \brief Structural equal function. */ std::vector fsequal_reduce_; /*! \brief Structural hash function. */ @@ -237,14 +189,14 @@ class ReflectionVTable::Registry { /*! * \brief Directly register reflection VTable. * \param TypeName The name of the type. - * \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce. + * \param TraitName A trait class that implements functions like SEqualReduce. * * \code * * // Example SEQualReduce traits for runtime StringObj. * * struct StringObjTrait { - * static constexpr const std::nullptr_t VisitAttrs = nullptr; + * * * static void SHashReduce(const StringObj* key, SHashReducer hash_reduce) { * hash_reduce->SHashReduceHashedValue(String::StableHashBytes(key->data, key->size)); @@ -286,16 +238,6 @@ class ReflectionVTable::Registry { // Implementation details namespace detail { -template -struct ImplVisitAttrs { - static constexpr const std::nullptr_t VisitAttrs = nullptr; -}; - -template -struct ImplVisitAttrs { - static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); } -}; - template struct ImplSEqualReduce { static constexpr const std::nullptr_t SEqualReduce = nullptr; @@ -321,22 +263,7 @@ struct ImplSHashReduce { }; template -struct ReflectionTrait : public ImplVisitAttrs, - public ImplSEqualReduce, - public ImplSHashReduce {}; - -template ::value> -struct SelectVisitAttrs { - static constexpr const std::nullptr_t VisitAttrs = nullptr; -}; - -template -struct SelectVisitAttrs { - static void VisitAttrs(Object* self, AttrVisitor* v) { - TraitName::VisitAttrs(static_cast(self), v); - } -}; +struct ReflectionTrait : public ImplSEqualReduce, public ImplSHashReduce {}; template ::value> @@ -370,16 +297,13 @@ struct SelectSHashReduce { template inline ReflectionVTable::Registry ReflectionVTable::Register() { uint32_t tindex = T::RuntimeTypeIndex(); - if (tindex >= fvisit_attrs_.size()) { - fvisit_attrs_.resize(tindex + 1, nullptr); + if (tindex >= fcreate_.size()) { fcreate_.resize(tindex + 1, nullptr); frepr_bytes_.resize(tindex + 1, nullptr); fsequal_reduce_.resize(tindex + 1, nullptr); fshash_reduce_.resize(tindex + 1, nullptr); } // functor that implements the redirection. - fvisit_attrs_[tindex] = ::tvm::detail::SelectVisitAttrs::VisitAttrs; - fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce::SEqualReduce; fshash_reduce_[tindex] = ::tvm::detail::SelectSHashReduce::SHashReduce; @@ -387,14 +311,6 @@ inline ReflectionVTable::Registry ReflectionVTable::Register() { return Registry(this, tindex); } -inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const { - uint32_t tindex = self->type_index(); - if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { - return; - } - fvisit_attrs_[tindex](self, visitor); -} - inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const { uint32_t tindex = self->type_index(); if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) { diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index ca547b17149a..c909b9dfc7e3 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -302,7 +302,6 @@ class SHashHandlerDefault : public SHashReducer::Handler { class SEqualReducer; struct NDArrayContainerTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce); static bool SEqualReduce(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, SEqualReducer equal); diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index c47385375fd7..b131decc8ae6 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -69,7 +69,7 @@ class DocNode : public Object { } static constexpr const char* _type_key = "script.printer.Doc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object); public: @@ -132,7 +132,7 @@ class ExprDocNode : public DocNode { } static constexpr const char* _type_key = "script.printer.ExprDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_BASE_OBJECT_INFO(ExprDocNode, DocNode); }; @@ -178,7 +178,7 @@ class StmtDocNode : public DocNode { } static constexpr const char* _type_key = "script.printer.StmtDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_BASE_OBJECT_INFO(StmtDocNode, DocNode); }; @@ -212,7 +212,7 @@ class StmtBlockDocNode : public DocNode { } static constexpr const char* _type_key = "script.printer.StmtBlockDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(StmtBlockDocNode, DocNode); }; @@ -254,7 +254,7 @@ class LiteralDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.LiteralDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(LiteralDocNode, ExprDocNode); }; @@ -344,7 +344,7 @@ class IdDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.IdDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(IdDocNode, ExprDocNode); }; @@ -384,7 +384,7 @@ class AttrAccessDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.AttrAccessDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrAccessDocNode, ExprDocNode); }; @@ -430,7 +430,7 @@ class IndexDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.IndexDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(IndexDocNode, ExprDocNode); }; @@ -481,7 +481,7 @@ class CallDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.CallDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(CallDocNode, ExprDocNode); }; @@ -565,7 +565,7 @@ class OperationDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.OperationDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(OperationDocNode, ExprDocNode); }; @@ -608,7 +608,7 @@ class LambdaDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.LambdaDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, ExprDocNode); }; @@ -644,7 +644,7 @@ class TupleDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.TupleDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleDocNode, ExprDocNode); }; @@ -683,7 +683,7 @@ class ListDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.ListDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ListDocNode, ExprDocNode); }; @@ -731,7 +731,7 @@ class DictDocNode : public ExprDocNode { } static constexpr const char* _type_key = "script.printer.DictDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(DictDocNode, ExprDocNode); }; @@ -780,7 +780,7 @@ class SliceDocNode : public DocNode { } static constexpr const char* _type_key = "script.printer.SliceDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(SliceDocNode, DocNode); }; @@ -828,7 +828,7 @@ class AssignDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.AssignDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(AssignDocNode, StmtDocNode); }; @@ -872,7 +872,7 @@ class IfDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.IfDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(IfDocNode, StmtDocNode); }; @@ -913,7 +913,7 @@ class WhileDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.WhileDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(WhileDocNode, StmtDocNode); }; @@ -960,7 +960,7 @@ class ForDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.ForDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ForDocNode, StmtDocNode); }; @@ -1009,7 +1009,7 @@ class ScopeDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.ScopeDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ScopeDocNode, StmtDocNode); }; @@ -1054,7 +1054,7 @@ class ExprStmtDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.ExprStmtDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ExprStmtDocNode, StmtDocNode); }; @@ -1093,7 +1093,7 @@ class AssertDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.AssertDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(AssertDocNode, StmtDocNode); }; @@ -1129,7 +1129,7 @@ class ReturnDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.ReturnDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ReturnDocNode, StmtDocNode); }; @@ -1183,7 +1183,7 @@ class FunctionDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.FunctionDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionDocNode, StmtDocNode); }; @@ -1230,7 +1230,7 @@ class ClassDocNode : public StmtDocNode { } static constexpr const char* _type_key = "script.printer.ClassDoc"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ClassDocNode, StmtDocNode); }; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 338b65e1cf6f..9e11f052aa03 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -59,7 +59,7 @@ class FrameNode : public Object { } static constexpr const char* _type_key = "script.printer.Frame"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object); public: @@ -163,7 +163,7 @@ class IRDocsifierNode : public Object { } static constexpr const char* _type_key = "script.printer.IRDocsifier"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(IRDocsifierNode, Object); public: diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h index 00542d43dce1..44c475b5f93c 100644 --- a/include/tvm/target/tag.h +++ b/include/tvm/target/tag.h @@ -49,7 +49,7 @@ class TargetTagNode : public Object { } static constexpr const char* _type_key = "target.TargetTag"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(TargetTagNode, Object); private: diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index cc334b785428..f1f354a43e0a 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -177,7 +177,7 @@ class TargetNode : public Object { void SHashReduce(SHashReducer hash_reduce) const; static constexpr const char* _type_key = "target.Target"; - static constexpr const bool _type_has_method_visit_attrs = false; + static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); diff --git a/include/tvm/target/target_info.h b/include/tvm/target/target_info.h index 552152dbde87..1600b3edde0f 100644 --- a/include/tvm/target/target_info.h +++ b/include/tvm/target/target_info.h @@ -59,7 +59,7 @@ class MemoryInfoNode : public Object { } static constexpr const char* _type_key = "target.MemoryInfo"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(MemoryInfoNode, Object); }; diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 3a451832499f..0f1faa3442b6 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -86,7 +86,7 @@ class TargetKindNode : public Object { } static constexpr const char* _type_key = "target.TargetKind"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object); private: diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index abf52a2528a1..498deb516d42 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -92,7 +92,6 @@ class TVM_DLL OperationNode : public Object { } static constexpr const char* _type_key = "te.Operation"; - static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object); }; @@ -120,7 +119,7 @@ class PlaceholderOpNode : public OperationNode { } static constexpr const char* _type_key = "te.PlaceholderOp"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; @@ -156,7 +155,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { } static constexpr const char* _type_key = "te.BaseComputeOp"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); }; @@ -180,7 +179,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { } static constexpr const char* _type_key = "te.ComputeOp"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); }; @@ -245,7 +244,7 @@ class ScanOpNode : public OperationNode { } static constexpr const char* _type_key = "te.ScanOp"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); }; @@ -294,7 +293,7 @@ class ExternOpNode : public OperationNode { } static constexpr const char* _type_key = "te.ExternOp"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); }; diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 56dce360ccf1..e35b9069a073 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -88,7 +88,7 @@ class TensorNode : public DataProducerNode { TVM_DLL String GetNameHint() const final; static constexpr const char* _type_key = "te.Tensor"; - static constexpr const bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); }; diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 6e283575c67c..cd70212d6d19 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -55,8 +55,6 @@ class StringImmNode : public PrimExprNode { /*! \brief The constant value content. */ String value; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &StringImmNode::value); @@ -92,8 +90,6 @@ class CastNode : public PrimExprNode { /*! \brief Original data type. */ PrimExpr value; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &CastNode::value); @@ -135,8 +131,6 @@ class BinaryOpNode : public PrimExprNode { /*! \brief The right operand. */ PrimExpr b; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); @@ -327,8 +321,6 @@ class CmpOpNode : public PrimExprNode { /*! \brief The right operand. */ PrimExpr b; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); @@ -457,8 +449,6 @@ class AndNode : public PrimExprNode { /*! \brief The right operand. */ PrimExpr b; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b); @@ -497,8 +487,6 @@ class OrNode : public PrimExprNode { /*! \brief The right operand. */ PrimExpr b; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b); @@ -535,8 +523,6 @@ class NotNode : public PrimExprNode { /*! \brief The input operand. */ PrimExpr a; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &NotNode::a); @@ -582,8 +568,6 @@ class SelectNode : public PrimExprNode { /*! \brief value to be returned when condition is false. */ PrimExpr false_value; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() @@ -1002,8 +986,6 @@ class CommReducerNode : public Object { */ mutable Span span; - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index b40a82e1cfdc..24c7e6944d04 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -286,8 +286,6 @@ class IterVarNode : public PrimExprConvertibleNode { PrimExpr ToPrimExpr() const final { return var; } - static constexpr const bool _type_has_method_visit_attrs = false; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index d8d6188e155c..e7de1a9f909b 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -30,26 +30,6 @@ class Attrs(Object): Attrs is passed as the first argument to these functions. """ - def list_field_info(self): - """Get fields information - - Returns - ------- - infos: list of AttrFieldInfo - List of field information - """ - return _ffi_api.AttrsListFieldInfo(self) - - def keys(self): - """Get list of names in the attribute. - - Returns - ------- - keys : list of str - List of keys - """ - return [field.name for field in self.list_field_info()] - def get_int_tuple(self, key): """Get a python int tuple of a key diff --git a/python/tvm/topi/gpu/scan.py b/python/tvm/topi/gpu/scan.py index f45702c6341f..5be4033e4575 100644 --- a/python/tvm/topi/gpu/scan.py +++ b/python/tvm/topi/gpu/scan.py @@ -40,6 +40,7 @@ def _can_use_scan_thrust(binop): target = tvm.target.Target.current() if target is None: return False + # pylint: disable=comparison-with-callable return binop == tvm.tir.generic.add and any( [ can_use_thrust(target, "tvm.contrib.thrust.sum_scan"), diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index ecb2639fea32..1a478956c2bf 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -107,7 +107,7 @@ struct MSCRBuildConfig { } }; -class AttrGetter : private AttrVisitor { +class AttrGetter { public: /*! * \brief Get the attributes as Map @@ -116,16 +116,18 @@ class AttrGetter : private AttrVisitor { explicit AttrGetter(Map* attrs) : attrs_(attrs) {} void operator()(const Attrs& attrs) { - // dispatch between new reflection and old reflection - const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index()); - if (attrs_tinfo->extra_info != nullptr) { - tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs); - this->VisitAny(String(field_info->name), field_value); - }); + if (const auto* dict_attrs = attrs.as()) { + for (const auto& [key, value] : dict_attrs->dict) { + this->VisitAny(key, value); + } } else { - // TODO(tvm-team): remove this once all objects are transitioned to the new reflection - const_cast(attrs.get())->VisitAttrs(this); + const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index()); + if (attrs_tinfo->extra_info != nullptr) { + tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs); + this->VisitAny(String(field_info->name), field_value); + }); + } } } @@ -167,50 +169,6 @@ class AttrGetter : private AttrVisitor { } } - void Visit(const char* key, double* value) final { attrs_->Set(key, std::to_string(*value)); } - - void Visit(const char* key, int64_t* value) final { attrs_->Set(key, std::to_string(*value)); } - - void Visit(const char* key, uint64_t* value) final { attrs_->Set(key, std::to_string(*value)); } - - void Visit(const char* key, int* value) final { attrs_->Set(key, std::to_string(*value)); } - - void Visit(const char* key, bool* value) final { attrs_->Set(key, std::to_string(*value)); } - - void Visit(const char* key, std::string* value) final { attrs_->Set(key, *value); } - - void Visit(const char* key, Optional* value) final { - if (value->has_value()) { - attrs_->Set(key, std::to_string(value->value())); - } else { - attrs_->Set(key, ""); - } - } - - void Visit(const char* key, Optional* value) final { - if (value->has_value()) { - attrs_->Set(key, std::to_string(value->value())); - } else { - attrs_->Set(key, ""); - } - } - - void Visit(const char* key, DataType* value) final { - attrs_->Set(key, runtime::DLDataTypeToString(*value)); - } - - void Visit(const char* key, runtime::ObjectRef* value) final { - attrs_->Set(key, StringUtils::ToString(*value)); - } - - void Visit(const char* key, void** value) final { - LOG(FATAL) << "TypeError: void is not allowed in Attrs"; - } - - void Visit(const char* key, runtime::NDArray* value) final { - LOG(FATAL) << "TypeError: NDArray is not allowed in Attrs"; - } - private: Map* attrs_; }; diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index a1bfc5783ee7..ff19cc0c03e9 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -28,9 +28,10 @@ namespace tvm { -void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } - -void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } +TVM_FFI_STATIC_INIT_BLOCK({ + AttrFieldInfoNode::RegisterReflection(); + DictAttrsNode::RegisterReflection(); +}); DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { if (new_attrs.empty()) { @@ -62,8 +63,6 @@ void DictAttrsNode::InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unk } } -Array DictAttrsNode::ListFieldInfo() const { return {}; } - DictAttrs::DictAttrs(Map dict) { ObjectPtr n = make_object(); n->dict = std::move(dict); @@ -80,8 +79,4 @@ TVM_FFI_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs return attrs->dict; }); -TVM_FFI_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) { - return attrs->ListFieldInfo(); -}); - } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index e3a6d886945a..159c1f2e5aa8 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -311,7 +311,7 @@ IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { ret = node->operator()(std::move(mod), pass_ctx); } pass_ctx.InstrumentAfterPass(ret, pass_info); - return std::move(ret); + return ret; } IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node, @@ -325,7 +325,7 @@ IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node, // must be very low. LOG_FATAL << "Immutable module has been modified in pass: " << node->Info()->name; } - return std::move(ret); + return ret; } /*! diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index e5dbc27f6411..6bfe2927c6e7 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -64,8 +64,6 @@ class ReplayFuncNode : public SearchStrategyNode { // No fields to register } - static constexpr const bool _type_has_method_visit_attrs = false; - static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index f6ad9e3770d0..ae55bc58f16e 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -80,8 +80,6 @@ class ReplayTraceNode : public SearchStrategyNode { refl::ObjectDef().def_ro("max_fail_count", &ReplayTraceNode::max_fail_count); } - static constexpr const bool _type_has_method_visit_attrs = false; - static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index d716e4f3e488..a6baaa78d53d 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -40,8 +40,6 @@ class PostOrderApplyNode : public SpaceGeneratorNode { // No fields to register } - static constexpr const bool _type_has_method_visit_attrs = false; - void InitializeWithTuneContext(const TuneContext& context) final { SpaceGeneratorNode::InitializeWithTuneContext(context); this->rand_state_ = ForkSeed(&context->rand_state); diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 06695cbb3f2c..0b52c58449b4 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -35,8 +35,6 @@ class ScheduleFnNode : public SpaceGeneratorNode { // `schedule_fn_` is not registered. } - static constexpr const bool _type_has_method_visit_attrs = false; - void InitializeWithTuneContext(const TuneContext& context) final { SpaceGeneratorNode::InitializeWithTuneContext(context); this->rand_state_ = ForkSeed(&context->rand_state); diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 24464ad31e31..d1f22e013b0d 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -35,8 +35,6 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { &SpaceGeneratorUnionNode::space_generators); } - static constexpr const bool _type_has_method_visit_attrs = false; - void InitializeWithTuneContext(const TuneContext& context) final { SpaceGeneratorNode::InitializeWithTuneContext(context); for (const SpaceGenerator& space_generator : space_generators) { diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 13c1c96924a3..8dfb1ebfc7bc 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -33,75 +33,10 @@ using ffi::Any; using ffi::Function; using ffi::PackedArgs; -// Attr getter. -class AttrGetter : public AttrVisitor { - public: - const String& skey; - ffi::Any* ret; - - AttrGetter(const String& skey, ffi::Any* ret) : skey(skey), ret(ret) {} - - bool found_ref_object{false}; - - void Visit(const char* key, double* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, int64_t* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, uint64_t* value) final { - ICHECK_LE(value[0], static_cast(std::numeric_limits::max())) - << "cannot return too big constant"; - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, int* value) final { - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, bool* value) final { - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, void** value) final { - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, DataType* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, std::string* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, Optional* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } - void Visit(const char* key, Optional* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } - - void Visit(const char* key, runtime::NDArray* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } - void Visit(const char* key, runtime::ObjectRef* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } -}; - ffi::Any ReflectionVTable::GetAttr(Object* self, const String& field_name) const { ffi::Any ret; - AttrGetter getter(field_name, &ret); - bool success; - if (getter.skey == "type_key") { + if (field_name == "type_key") { ret = self->GetTypeKey(); success = true; } else if (!self->IsInstance()) { @@ -116,15 +51,11 @@ ffi::Any ReflectionVTable::GetAttr(Object* self, const String& field_name) const success = true; } }); - } else { - // legacy reflection mechanism, will be phased out in the future - VisitAttrs(self, &getter); - success = getter.found_ref_object || ret != nullptr; } } else { // specially handle dict attr DictAttrsNode* dnode = static_cast(self); - auto it = dnode->dict.find(getter.skey); + auto it = dnode->dict.find(field_name); if (it != dnode->dict.end()) { success = true; ret = (*it).second; @@ -134,34 +65,13 @@ ffi::Any ReflectionVTable::GetAttr(Object* self, const String& field_name) const } if (!success) { LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attributed " - << getter.skey; + << field_name; } return ret; } -// List names; -class AttrDir : public AttrVisitor { - public: - std::vector* names; - - void Visit(const char* key, double* value) final { names->push_back(key); } - void Visit(const char* key, int64_t* value) final { names->push_back(key); } - void Visit(const char* key, uint64_t* value) final { names->push_back(key); } - void Visit(const char* key, bool* value) final { names->push_back(key); } - void Visit(const char* key, int* value) final { names->push_back(key); } - void Visit(const char* key, void** value) final { names->push_back(key); } - void Visit(const char* key, DataType* value) final { names->push_back(key); } - void Visit(const char* key, std::string* value) final { names->push_back(key); } - void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); } - void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); } - void Visit(const char* key, Optional* value) final { names->push_back(key); } - void Visit(const char* key, Optional* value) final { names->push_back(key); } -}; - std::vector ReflectionVTable::ListAttrNames(Object* self) const { std::vector names; - AttrDir dir; - dir.names = &names; if (!self->IsInstance()) { const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); @@ -170,9 +80,6 @@ std::vector ReflectionVTable::ListAttrNames(Object* self) const { ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { names.push_back(std::string(field_info->name.data, field_info->name.size)); }); - } else { - // legacy reflection mechanism, will be phased out in the future - VisitAttrs(self, &dir); } } else { // specially handle dict attr @@ -200,91 +107,26 @@ ObjectPtr ReflectionVTable::CreateInitObject(const std::string& type_key return fcreate_[tindex](repr_bytes); } -class NodeAttrSetter : public AttrVisitor { - public: - std::string type_key; - std::unordered_map attrs; - - void Visit(const char* key, double* value) final { *value = GetAttr(key).cast(); } - void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).cast(); } - void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).cast(); } - void Visit(const char* key, int* value) final { *value = GetAttr(key).cast(); } - void Visit(const char* key, bool* value) final { *value = GetAttr(key).cast(); } - void Visit(const char* key, std::string* value) final { - *value = GetAttr(key).cast(); - } - void Visit(const char* key, void** value) final { *value = GetAttr(key).cast(); } - void Visit(const char* key, DataType* value) final { *value = GetAttr(key).cast(); } - void Visit(const char* key, runtime::NDArray* value) final { - *value = GetAttr(key).cast(); - } - void Visit(const char* key, ObjectRef* value) final { *value = GetAttr(key).cast(); } - - void Visit(const char* key, Optional* value) final { - *value = GetAttr(key).cast>(); - } - void Visit(const char* key, Optional* value) final { - *value = GetAttr(key).cast>(); - } - - private: - ffi::AnyView GetAttr(const char* key) { - auto it = attrs.find(key); - if (it == attrs.end()) { - LOG(FATAL) << type_key << ": require field " << key; - } - ffi::AnyView v = it->second; - attrs.erase(it); - return v; - } -}; - -void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const ffi::PackedArgs& args) { - NodeAttrSetter setter; - setter.type_key = n->GetTypeKey(); - ICHECK_EQ(args.size() % 2, 0); - for (int i = 0; i < args.size(); i += 2) { - setter.attrs.emplace(args[i].cast(), args[i + 1]); - } - reflection->VisitAttrs(n, &setter); - - if (setter.attrs.size() != 0) { - std::ostringstream os; - os << setter.type_key << " does not contain field "; - for (const auto& kv : setter.attrs) { - os << " " << kv.first; - } - LOG(FATAL) << os.str(); - } -} - ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const ffi::PackedArgs& kwargs) { - // dispatch between new reflection and old reflection int32_t type_index; TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - if (type_info->extra_info != nullptr) { - auto fcreate_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); - std::vector packed_args(kwargs.size() + 1); - packed_args[0] = type_index; - for (int i = 0; i < kwargs.size(); i++) { - packed_args[i + 1] = kwargs[i]; - } - ffi::Any rv; - fcreate_object.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &rv); - return rv.cast(); - } else { - // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + if (type_index == DictAttrsNode::RuntimeTypeIndex()) { ObjectPtr n = this->CreateInitObject(type_key); - if (n->IsInstance()) { - static_cast(n.get())->InitByPackedArgs(kwargs); - } else { - InitNodeByPackedArgs(this, n.get(), kwargs); - } + static_cast(n.get())->InitByPackedArgs(kwargs); return ObjectRef(n); } + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + auto fcreate_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); + std::vector packed_args(kwargs.size() + 1); + packed_args[0] = type_index; + for (int i = 0; i < kwargs.size(); i++) { + packed_args[i + 1] = kwargs[i]; + } + ffi::Any rv; + fcreate_object.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &rv); + return rv.cast(); } ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, @@ -339,41 +181,7 @@ TVM_FFI_REGISTER_GLOBAL("node.NodeListAttrNames").set_body_packed(NodeListAttrNa TVM_FFI_REGISTER_GLOBAL("node.MakeNode").set_body_packed(MakeNode); -namespace { -// Attribute visitor class for finding the attribute key by its address -class GetAttrKeyByAddressVisitor : public AttrVisitor { - public: - explicit GetAttrKeyByAddressVisitor(const void* attr_address) - : attr_address_(attr_address), key_(nullptr) {} - - void Visit(const char* key, double* value) final { DoVisit(key, value); } - void Visit(const char* key, int64_t* value) final { DoVisit(key, value); } - void Visit(const char* key, uint64_t* value) final { DoVisit(key, value); } - void Visit(const char* key, int* value) final { DoVisit(key, value); } - void Visit(const char* key, bool* value) final { DoVisit(key, value); } - void Visit(const char* key, std::string* value) final { DoVisit(key, value); } - void Visit(const char* key, void** value) final { DoVisit(key, value); } - void Visit(const char* key, DataType* value) final { DoVisit(key, value); } - void Visit(const char* key, runtime::NDArray* value) final { DoVisit(key, value); } - void Visit(const char* key, runtime::ObjectRef* value) final { DoVisit(key, value); } - void Visit(const char* key, Optional* value) final { DoVisit(key, value); } - void Visit(const char* key, Optional* value) final { DoVisit(key, value); } - const char* GetKey() const { return key_; } - - private: - const void* attr_address_; - const char* key_; - - void DoVisit(const char* key, const void* candidate) { - if (attr_address_ == candidate) { - key_ = key; - } - } -}; -} // anonymous namespace - Optional GetAttrKeyByAddress(const Object* object, const void* attr_address) { - // NOTE: reflection dispatch for both new and legacy reflection mechanism const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(object->type_index()); if (tinfo->extra_info != nullptr) { Optional result; @@ -388,17 +196,8 @@ Optional GetAttrKeyByAddress(const Object* object, const void* attr_addr return false; }); return result; - } else { - // TODO(tvm-team): remove this path once all objects are transitioned to the new reflection - GetAttrKeyByAddressVisitor visitor(attr_address); - ReflectionVTable::Global()->VisitAttrs(const_cast(object), &visitor); - const char* key = visitor.GetKey(); - if (key == nullptr) { - return std::nullopt; - } else { - return String(key); - } } + return std::nullopt; } } // namespace tvm diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 116a543f58e1..60b455589ff9 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -63,28 +63,12 @@ inline std::string Base64Encode(std::string s) { } // indexer to index all the nodes -class NodeIndexer : private AttrVisitor { +class NodeIndexer { public: std::unordered_map node_index_{{Any(nullptr), 0}}; std::vector node_list_{Any(nullptr)}; ReflectionVTable* reflection_ = ReflectionVTable::Global(); - void Visit(const char* key, double* value) final {} - void Visit(const char* key, int64_t* value) final {} - void Visit(const char* key, uint64_t* value) final {} - void Visit(const char* key, int* value) final {} - void Visit(const char* key, bool* value) final {} - void Visit(const char* key, std::string* value) final {} - void Visit(const char* key, void** value) final {} - void Visit(const char* key, DataType* value) final {} - - void Visit(const char* key, runtime::NDArray* value) final { MakeIndex(Any(*value)); } - - void Visit(const char* key, Optional* value) final {} - void Visit(const char* key, Optional* value) final {} - - void Visit(const char* key, ObjectRef* value) final { MakeIndex(Any(*value)); } - void MakeNodeIndex(Any node) { if (node == nullptr) return; if (node_index_.count(node)) { @@ -134,18 +118,16 @@ class NodeIndexer : private AttrVisitor { void VisitObjectFields(Object* obj) { const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); - if (tinfo->extra_info != nullptr) { - ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = ffi::reflection::FieldGetter(field_info)(obj); - // only make index for ObjectRef - if (field_value.as()) { - this->MakeIndex(field_value); - } - }); - } else { - // TODO(tvm-team): remove this once all objects are transitioned to the new reflection - reflection_->VisitAttrs(obj, this); - } + ICHECK(tinfo->extra_info != nullptr) + << "Object `" << obj->GetTypeKey() + << "` misses reflection registration and do not support serialization"; + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + // only make index for ObjectRef + if (field_value.as()) { + this->MakeIndex(field_value); + } + }); } }; @@ -221,49 +203,35 @@ struct JSONNode { // Helper class to populate the json node // using the existing index. -class JSONAttrGetter : private AttrVisitor { +class JSONAttrGetter { public: const std::unordered_map* node_index_; JSONNode* node_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); - void Visit(const char* key, double* value) final { + void Visit(const char* key, double* value) { std::ostringstream s; // Save 17 decimal digits for type to avoid precision loss during loading JSON s.precision(17); s << (*value); node_->attrs[key] = s.str(); } - void Visit(const char* key, int64_t* value) final { node_->attrs[key] = std::to_string(*value); } - void Visit(const char* key, uint64_t* value) final { node_->attrs[key] = std::to_string(*value); } - void Visit(const char* key, int* value) final { node_->attrs[key] = std::to_string(*value); } - void Visit(const char* key, bool* value) final { node_->attrs[key] = std::to_string(*value); } - void Visit(const char* key, std::string* value) final { node_->attrs[key] = *value; } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "not allowed to serialize a pointer"; - } - void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); } - void Visit(const char* key, runtime::NDArray* value) final { - Visit(key, static_cast(value)); - } - - void Visit(const char* key, Optional* value) final { + void Visit(const char* key, int64_t* value) { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, uint64_t* value) { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, int* value) { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, bool* value) { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, std::string* value) { node_->attrs[key] = *value; } + void Visit(const char* key, void** value) { LOG(FATAL) << "not allowed to serialize a pointer"; } + void Visit(const char* key, DataType* value) { node_->attrs[key] = Type2String(*value); } + void Visit(const char* key, Optional* value) { if (value->has_value()) { node_->attrs[key] = std::to_string(value->value()); } else { node_->attrs[key] = "null"; } } - void Visit(const char* key, Optional* value) final { - if (value->has_value()) { - double val = **value; - Visit(key, &val); - } else { - node_->attrs[key] = "null"; - } - } - void Visit(const char* key, ObjectRef* value) final { + void Visit(const char* key, ObjectRef* value) { if (value->defined()) { node_->attrs[key] = std::to_string(node_index_->at(Any(*value))); } else { @@ -343,55 +311,53 @@ class JSONAttrGetter : private AttrVisitor { void VisitObjectFields(Object* obj) { // dispatch between new reflection and old reflection const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); - if (tinfo->extra_info != nullptr) { - ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = ffi::reflection::FieldGetter(field_info)(obj); - String field_name(field_info->name); - switch (field_value.type_index()) { - case ffi::TypeIndex::kTVMFFINone: { - node_->attrs[field_name] = "null"; - break; - } - case ffi::TypeIndex::kTVMFFIBool: - case ffi::TypeIndex::kTVMFFIInt: { - int64_t value = field_value.cast(); - this->Visit(field_info->name.data, &value); - break; - } - case ffi::TypeIndex::kTVMFFIFloat: { - double value = field_value.cast(); - this->Visit(field_info->name.data, &value); - break; - } - case ffi::TypeIndex::kTVMFFIDataType: { - DataType value(field_value.cast()); - this->Visit(field_info->name.data, &value); - break; - } - case ffi::TypeIndex::kTVMFFINDArray: { - runtime::NDArray value = field_value.cast(); - this->Visit(field_info->name.data, &value); + ICHECK(tinfo->extra_info != nullptr) + << "Object `" << obj->GetTypeKey() + << "` misses reflection registration and do not support serialization"; + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + String field_name(field_info->name); + switch (field_value.type_index()) { + case ffi::TypeIndex::kTVMFFINone: { + node_->attrs[field_name] = "null"; + break; + } + case ffi::TypeIndex::kTVMFFIBool: + case ffi::TypeIndex::kTVMFFIInt: { + int64_t value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFIFloat: { + double value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFIDataType: { + DataType value(field_value.cast()); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFINDArray: { + runtime::NDArray value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + default: { + if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + ObjectRef obj = field_value.cast(); + this->Visit(field_info->name.data, &obj); break; - } - default: { - if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - ObjectRef obj = field_value.cast(); - this->Visit(field_info->name.data, &obj); - break; - } else { - LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); - } + } else { + LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); } } - }); - } else { - // TODO(tvm-team): remove this once all objects are transitioned to the new reflection - reflection_->VisitAttrs(obj, this); - } + } + }); } }; -class FieldDependencyFinder : private AttrVisitor { +class FieldDependencyFinder { public: JSONNode* jnode_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); @@ -423,24 +389,7 @@ class FieldDependencyFinder : private AttrVisitor { *value = temp; } } - void Visit(const char* key, double* value) final {} - void Visit(const char* key, int64_t* value) final {} - void Visit(const char* key, uint64_t* value) final {} - void Visit(const char* key, int* value) final {} - void Visit(const char* key, bool* value) final {} - void Visit(const char* key, std::string* value) final {} - void Visit(const char* key, void** value) final {} - void Visit(const char* key, DataType* value) final {} - void Visit(const char* key, runtime::NDArray* value) final {} - void Visit(const char* key, Optional* value) final {} - void Visit(const char* key, Optional* value) final {} - void Visit(const char* key, ObjectRef* value) final { - Optional index; - ParseOptionalValue(key, &index); - if (index.has_value()) { - jnode_->fields.push_back(*index); - } - } + void Find(Any node, JSONNode* jnode) { // Skip None if (node == nullptr) { @@ -469,27 +418,25 @@ class FieldDependencyFinder : private AttrVisitor { void VisitObjectFields(Object* obj) { // dispatch between new reflection and old reflection const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); - if (tinfo->extra_info != nullptr) { - ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { - if (field_info->field_static_type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin || - field_info->field_static_type_index == ffi::TypeIndex::kTVMFFIAny) { - Optional index; - ParseOptionalValue(field_info->name.data, &index); - if (index.has_value()) { - jnode_->fields.push_back(*index); - } + ICHECK(tinfo->extra_info != nullptr) + << "Object `" << obj->GetTypeKey() + << "` misses reflection registration and do not support serialization"; + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + if (field_info->field_static_type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin || + field_info->field_static_type_index == ffi::TypeIndex::kTVMFFIAny) { + Optional index; + ParseOptionalValue(field_info->name.data, &index); + if (index.has_value()) { + jnode_->fields.push_back(*index); } - }); - } else { - // TODO(tvm-team): remove this once all objects are transitioned to the new reflection - reflection_->VisitAttrs(obj, this); - } + } + }); } }; // Helper class to set the attributes of a node // from given json node. -class JSONAttrSetter : private AttrVisitor { +class JSONAttrSetter { public: const std::vector* node_list_; JSONNode* jnode_; @@ -538,33 +485,33 @@ class JSONAttrSetter : private AttrVisitor { } } - void Visit(const char* key, double* value) final { ParseDouble(key, value); } - void Visit(const char* key, int64_t* value) final { ParseValue(key, value); } - void Visit(const char* key, uint64_t* value) final { ParseValue(key, value); } - void Visit(const char* key, int* value) final { ParseValue(key, value); } - void Visit(const char* key, bool* value) final { ParseValue(key, value); } - void Visit(const char* key, std::string* value) final { *value = GetValue(key); } + void Visit(const char* key, double* value) { ParseDouble(key, value); } + void Visit(const char* key, int64_t* value) { ParseValue(key, value); } + void Visit(const char* key, uint64_t* value) { ParseValue(key, value); } + void Visit(const char* key, int* value) { ParseValue(key, value); } + void Visit(const char* key, bool* value) { ParseValue(key, value); } + void Visit(const char* key, std::string* value) { *value = GetValue(key); } - void Visit(const char* key, Optional* value) final { + void Visit(const char* key, Optional* value) { ParseOptionalValue(key, value, [this](const char* key, double* value) { ParseDouble(key, value); }); } - void Visit(const char* key, Optional* value) final { + void Visit(const char* key, Optional* value) { ParseOptionalValue( key, value, [this](const char* key, int64_t* value) { ParseValue(key, value); }); } - void Visit(const char* key, void** value) final { + void Visit(const char* key, void** value) { LOG(FATAL) << "not allowed to deserialize a pointer"; } - void Visit(const char* key, DataType* value) final { + void Visit(const char* key, DataType* value) { std::string stype = GetValue(key); *value = String2Type(stype); } - void Visit(const char* key, runtime::NDArray* value) final { + void Visit(const char* key, runtime::NDArray* value) { Visit(key, static_cast(value)); } - void Visit(const char* key, ObjectRef* value) final { + void Visit(const char* key, ObjectRef* value) { Optional index; ParseOptionalValue(key, &index, [this](const char* key, int64_t* value) { ParseValue(key, value); }); @@ -645,13 +592,11 @@ class JSONAttrSetter : private AttrVisitor { void SetObjectFields(Object* obj) { // dispatch between new reflection and old reflection const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); - if (tinfo->extra_info != nullptr) { - ffi::reflection::ForEachFieldInfo( - tinfo, [&](const TVMFFIFieldInfo* field_info) { this->SetObjectField(obj, field_info); }); - } else { - // TODO(tvm-team): remove this once all objects are transitioned to the new reflection - reflection_->VisitAttrs(obj, this); - } + ICHECK(tinfo->extra_info != nullptr) + << "Object `" << obj->GetTypeKey() + << "` misses reflection registration and do not support serialization"; + ffi::reflection::ForEachFieldInfo( + tinfo, [&](const TVMFFIFieldInfo* field_info) { this->SetObjectField(obj, field_info); }); } void SetObjectField(Object* obj, const TVMFFIFieldInfo* field_info) { diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 538cbab837c1..1f1d476d5cf3 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -313,8 +313,6 @@ void SHashHandlerIgnoreNDArray::DispatchSHash(const ObjectRef& object, bool map_ // SEQualReduce traits for runtime containers. struct StringObjTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ffi::StringObj* key, SHashReducer hash_reduce) { hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size)); } @@ -329,8 +327,6 @@ struct StringObjTrait { }; struct BytesObjTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ffi::BytesObj* key, SHashReducer hash_reduce) { hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size)); } @@ -375,7 +371,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); struct ModuleNodeTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; static constexpr const std::nullptr_t SHashReduce = nullptr; static constexpr const std::nullptr_t SEqualReduce = nullptr; }; @@ -431,8 +426,6 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrai }); struct ArrayObjTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ffi::ArrayObj* key, SHashReducer hash_reduce) { hash_reduce(static_cast(key->size())); for (uint32_t i = 0; i < key->size(); ++i) { @@ -518,8 +511,6 @@ TVM_REGISTER_REFLECTION_VTABLE(ffi::ArrayObj, ArrayObjTrait) }); struct ShapeObjTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ffi::ShapeObj* self, SHashReducer hash_reduce) { hash_reduce(static_cast(self->size)); for (uint32_t i = 0; i < self->size; ++i) { @@ -562,8 +553,6 @@ TVM_REGISTER_REFLECTION_VTABLE(ffi::ShapeObj, ShapeObjTrait) }); struct MapObjTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduceForOMap(const ffi::MapObj* key, SHashReducer hash_reduce) { // SHash's var handling depends on the determinism of traversal. // NOTE: only book-keep the mapped hash keys. @@ -707,8 +696,6 @@ struct ReportNodeTrait { .def_ro("configuration", &runtime::profiling::ReportNode::configuration); } - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; @@ -728,8 +715,6 @@ struct CountNodeTrait { &runtime::profiling::CountNode::value); } - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; @@ -750,8 +735,6 @@ struct DurationNodeTrait { "microseconds", &runtime::profiling::DurationNode::microseconds); } - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; @@ -772,8 +755,6 @@ struct PercentNodeTrait { "percent", &runtime::profiling::PercentNode::percent); } - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; @@ -794,8 +775,6 @@ struct RatioNodeTrait { &runtime::profiling::RatioNode::ratio); } - static constexpr const std::nullptr_t VisitAttrs = nullptr; - static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 9aa693f58b56..b2b0423257fa 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -57,7 +57,7 @@ using JSONGraphObjectPtr = std::shared_ptr; * \brief Helper class to extract all attributes of a certain op and save them * into text format. */ -class OpAttrExtractor : private AttrVisitor { +class OpAttrExtractor { public: explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {} @@ -75,19 +75,19 @@ class OpAttrExtractor : private AttrVisitor { node_->SetAttr(key, attr); } - void Visit(const char* key, double* value) final { SetNodeAttr(key, {Fp2String(*value)}); } + void Visit(const char* key, double* value) { SetNodeAttr(key, {Fp2String(*value)}); } - void Visit(const char* key, int64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + void Visit(const char* key, int64_t* value) { SetNodeAttr(key, {std::to_string(*value)}); } - void Visit(const char* key, uint64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + void Visit(const char* key, uint64_t* value) { SetNodeAttr(key, {std::to_string(*value)}); } - void Visit(const char* key, int* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + void Visit(const char* key, int* value) { SetNodeAttr(key, {std::to_string(*value)}); } - void Visit(const char* key, bool* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + void Visit(const char* key, bool* value) { SetNodeAttr(key, {std::to_string(*value)}); } - void Visit(const char* key, std::string* value) final { SetNodeAttr(key, {*value}); } + void Visit(const char* key, std::string* value) { SetNodeAttr(key, {*value}); } - void Visit(const char* key, Optional* value) final { + void Visit(const char* key, Optional* value) { if (value->has_value()) { SetNodeAttr(key, {Fp2String(value->value())}); } else { @@ -95,7 +95,7 @@ class OpAttrExtractor : private AttrVisitor { } } - void Visit(const char* key, Optional* value) final { + void Visit(const char* key, Optional* value) { if (value->has_value()) { SetNodeAttr(key, {std::to_string(value->value())}); } else { @@ -103,7 +103,7 @@ class OpAttrExtractor : private AttrVisitor { } } - void Visit(const char* key, DataType* value) final { + void Visit(const char* key, DataType* value) { if (!value->is_void()) { SetNodeAttr(key, {runtime::DLDataTypeToString(*value)}); } else { @@ -111,7 +111,7 @@ class OpAttrExtractor : private AttrVisitor { } } - void Visit(const char* key, runtime::ObjectRef* value) final { + void Visit(const char* key, runtime::ObjectRef* value) { if (const auto* an = (*value).as()) { std::vector attr; for (size_t i = 0; i < an->size(); ++i) { @@ -141,14 +141,6 @@ class OpAttrExtractor : private AttrVisitor { } } - void Visit(const char* key, runtime::NDArray* value) final { - LOG(FATAL) << "NDArray is not allowed in op attribute"; - } - - void Visit(const char* key, void** value) final { - LOG(FATAL) << "void pointer is not allowed in op attribute"; - } - void Extract(Object* node) { if (node) { this->VisitObjectFields(node); @@ -158,49 +150,47 @@ class OpAttrExtractor : private AttrVisitor { private: void VisitObjectFields(Object* obj) { const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); - if (tinfo->extra_info != nullptr) { - ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = ffi::reflection::FieldGetter(field_info)(obj); - switch (field_value.type_index()) { - case ffi::TypeIndex::kTVMFFINone: { - SetNodeAttr(field_info->name.data, {""}); - break; - } - case ffi::TypeIndex::kTVMFFIBool: - case ffi::TypeIndex::kTVMFFIInt: { - int64_t value = field_value.cast(); - this->Visit(field_info->name.data, &value); - break; - } - case ffi::TypeIndex::kTVMFFIFloat: { - double value = field_value.cast(); - this->Visit(field_info->name.data, &value); - break; - } - case ffi::TypeIndex::kTVMFFIDataType: { - DataType value(field_value.cast()); - this->Visit(field_info->name.data, &value); - break; - } - case ffi::TypeIndex::kTVMFFINDArray: { - runtime::NDArray value = field_value.cast(); - this->Visit(field_info->name.data, &value); + ICHECK(tinfo->extra_info != nullptr) + << "Object `" << obj->GetTypeKey() + << "` misses reflection registration and do not support serialization"; + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + switch (field_value.type_index()) { + case ffi::TypeIndex::kTVMFFINone: { + SetNodeAttr(field_info->name.data, {""}); + break; + } + case ffi::TypeIndex::kTVMFFIBool: + case ffi::TypeIndex::kTVMFFIInt: { + int64_t value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFIFloat: { + double value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFIDataType: { + DataType value(field_value.cast()); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFINDArray: { + runtime::NDArray value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + default: { + if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + ObjectRef obj = field_value.cast(); + this->Visit(field_info->name.data, &obj); break; } - default: { - if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - ObjectRef obj = field_value.cast(); - this->Visit(field_info->name.data, &obj); - break; - } - LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); - } + LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); } - }); - } else { - // TODO(tvm-team): remove this once all objects are transitioned to the new reflection - reflection_->VisitAttrs(obj, this); - } + } + }); } JSONGraphObjectPtr node_; diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 50e71ccc14ed..986626a6eae0 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -528,7 +528,7 @@ class VMShapeLowerMutator builder_->Emit(call, "_"); } } - return std::move(outstanding_todos); + return outstanding_todos; } /*! diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index e6dd082f9a24..c65fe1d0ddeb 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -265,7 +265,7 @@ class RemoveUnusedVars : public ExprMutator { caught_rewrite = Downcast(output); } - return std::move(output); + return output; } private: diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 0a51e9cd4acb..7521b21d9418 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -400,7 +400,7 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana, for (ObjectRef var_name : non_negative_var_attr_raw) { const auto* key = var_name.as(); CHECK(key != nullptr) << "The element of attr `tir_non_negative_var` should be string. However " - << key->GetTypeKey() << " is got."; + << var_name->GetTypeKey() << " is got."; non_negative_var_attr.insert(GetRef(key)); } Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 173eb58a306b..a7530f4ba57b 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -111,47 +111,35 @@ void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, ffi::TypedFunction is_var) { - class Visitor : private AttrVisitor { + class Visitor { public: - void operator()(ObjectRef obj) { this->Visit("", &obj); } + void operator()(ObjectRef obj) { this->VisitObjectRef(obj); } private: void RecursiveVisitAny(ffi::Any* value) { if (std::optional opt = value->as()) { - this->Visit("", &opt.value()); + this->VisitObjectRef(*opt); } } - void Visit(const char* key, double* value) final {} - void Visit(const char* key, int64_t* value) final {} - void Visit(const char* key, uint64_t* value) final {} - void Visit(const char* key, int* value) final {} - void Visit(const char* key, bool* value) final {} - void Visit(const char* key, std::string* value) final {} - void Visit(const char* key, void** value) final {} - void Visit(const char* key, DataType* value) final {} - void Visit(const char* key, runtime::NDArray* value) final {} - void Visit(const char* key, Optional* value) final {} - void Visit(const char* key, Optional* value) final {} - void Visit(const char* key, ObjectRef* value) final { - const Object* obj = value->get(); - if (obj == nullptr) { + void VisitObjectRef(ObjectRef obj) { + if (!obj.defined()) { return; } - if (visited_.count(obj)) { - if (is_var(GetRef(obj))) { - HandleVar(obj); + if (visited_.count(obj.get())) { + if (is_var(obj)) { + HandleVar(obj.get()); } return; } - visited_.insert(obj); - stack_.push_back(obj); + visited_.insert(obj.get()); + stack_.push_back(obj.get()); if (obj->IsInstance()) { - const ffi::ArrayObj* array = static_cast(obj); + const ffi::ArrayObj* array = static_cast(obj.get()); for (Any element : *array) { this->RecursiveVisitAny(&element); } } else if (obj->IsInstance()) { - const ffi::MapObj* map = static_cast(obj); + const ffi::MapObj* map = static_cast(obj.get()); for (std::pair kv : *map) { this->RecursiveVisitAny(&kv.first); this->RecursiveVisitAny(&kv.second); @@ -159,18 +147,14 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, } else { const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); if (tinfo->extra_info != nullptr) { - // visit fields with the new reflection ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { Any field_value = ffi::reflection::FieldGetter(field_info)(obj); this->RecursiveVisitAny(&field_value); }); - } else { - // legacy VisitAttrs mechanism - vtable_->VisitAttrs(const_cast(obj), this); } } - if (is_var(GetRef(obj))) { - HandleVar(obj); + if (is_var(obj)) { + HandleVar(obj.get()); } stack_.pop_back(); } diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 9c4efadc2b83..f629ea52bf66 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -26,16 +26,23 @@ namespace tvm { namespace script { namespace printer { -class AttrPrinter : private AttrVisitor { +class AttrPrinter { public: explicit AttrPrinter(ObjectPath p, const IRDocsifier& d, Array* keys, Array* values) : p(std::move(p)), d(d), keys(keys), values(values) {} void operator()(const tvm::Attrs& attrs) { - // NOTE: reflection dispatch for both new and legacy reflection mechanism - const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index()); - if (attrs_tinfo->extra_info != nullptr && attrs_tinfo->extra_info->creator != nullptr) { + if (const auto* dict_attrs = attrs.as()) { + for (const auto& [key, value] : dict_attrs->dict) { + keys->push_back(key); + values->push_back(d->AsDoc(value, p->Attr(key))); + } + } else { + const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index()); + ICHECK(attrs_tinfo->extra_info != nullptr) + << "Object `" << attrs->GetTypeKey() + << "` misses reflection registration and do not support serialization"; // new printing mechanism using the new reflection ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { String field_name = String(field_info->name); @@ -43,75 +50,6 @@ class AttrPrinter : private AttrVisitor { keys->push_back(field_name); values->push_back(d->AsDoc(field_value, p->Attr(field_name))); }); - } else { - const_cast(attrs.get())->VisitAttrs(this); - } - } - - private: - void Visit(const char* key, double* value) final { - keys->push_back(key); - values->push_back(LiteralDoc::Float(*value, p->Attr(key))); - } - - void Visit(const char* key, int64_t* value) final { - keys->push_back(key); - values->push_back(LiteralDoc::Int(*value, p->Attr(key))); - } - - void Visit(const char* key, uint64_t* value) final { - keys->push_back(key); - values->push_back(LiteralDoc::Int(*value, p->Attr(key))); - } - - void Visit(const char* key, int* value) final { - keys->push_back(key); - values->push_back(LiteralDoc::Int(*value, p->Attr(key))); - } - - void Visit(const char* key, bool* value) final { - keys->push_back(key); - values->push_back(LiteralDoc::Boolean(*value, p->Attr(key))); - } - - void Visit(const char* key, std::string* value) final { - keys->push_back(key); - values->push_back(LiteralDoc::Str(*value, p->Attr(key))); - } - - void Visit(const char* key, DataType* value) final { - keys->push_back(key); - values->push_back(LiteralDoc::DataType(*value, p->Attr(key))); - } - - void Visit(const char* key, runtime::ObjectRef* value) final { - keys->push_back(key); - values->push_back(d->AsDoc(*value, p->Attr(key))); - } - - void Visit(const char* key, void** value) final { - LOG(FATAL) << "TypeError: void is not allowed in Attrs"; - } - - void Visit(const char* key, runtime::NDArray* value) final { - LOG(FATAL) << "TypeError: NDArray is not allowed in Attrs"; - } - - void Visit(const char* key, Optional* value) final { - keys->push_back(key); - if (value->has_value()) { - values->push_back(LiteralDoc::Float(value->value(), p->Attr(key))); - } else { - values->push_back(LiteralDoc::None(p->Attr(key))); - } - } - - void Visit(const char* key, Optional* value) final { - keys->push_back(key); - if (value->has_value()) { - values->push_back(LiteralDoc::Int(value->value(), p->Attr(key))); - } else { - values->push_back(LiteralDoc::None(p->Attr(key))); } } diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 1534cfc35889..27993f4a5bf6 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -424,7 +424,7 @@ Stmt GenerateBodyStmt(const Array& indices, const Array& buffe const PrimExpr& compute_body = f_transform_and_remap(expr_body); body = BufferStore(buffers[0], analyzer->Simplify(compute_body), indices); } - return std::move(body); + return body; } /*! \brief Record loops, block vars and binding in the single level scope. */ diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 304fe0bf820f..c0a9400e422f 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -585,7 +585,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) { is_enabled_ = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \ auto result = Parent::VisitExpr_(op); \ is_enabled_ = is_enabled; \ - return std::move(result); \ + return result; \ } TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 70cb57dc6423..60b196d71b06 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -352,13 +352,13 @@ template inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { T rv; this->symbol_table_.Set(rv, sref); - return std::move(rv); + return rv; } inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { Var rv("v" + std::to_string(this->symbol_table_.size() + 1), DataType::Int(32)); this->symbol_table_.Set(rv, Integer(static_cast(value))); - return std::move(rv); + return rv; } inline Array ConcreteScheduleNode::CreateRV(const std::vector& value, diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 94db2070c709..ea081721a3a1 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -390,7 +390,7 @@ class DecomposePaddingBlockReplacer : public StmtMutator { // position to insert pad value filling code return std::move(SeqStmt({desc_.const_filling_loop, new_loop})); } - return std::move(new_loop); + return new_loop; } private: From fbd8aab26200dd16ed7643b5fc457782a5daed77 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 2 Jul 2025 08:21:42 -0400 Subject: [PATCH 2/3] update local copy move return to fix compiler warning --- apps/hexagon_launcher/launcher_util.cc | 2 +- src/arith/canonical_simplify.cc | 22 ++++++++++-------- src/arith/int_set.cc | 2 +- src/arith/iter_affine_map.cc | 20 +++++++++------- src/arith/narrow_predicate_expression.cc | 3 ++- src/arith/rewrite_simplify.cc | 6 ++--- src/relax/ir/transform.cc | 2 +- src/relax/op/op.cc | 4 ++-- src/relax/transform/convert_layout.cc | 2 +- src/relax/transform/fold_constant.cc | 2 +- src/relax/transform/fuse_ops.cc | 2 +- src/relax/transform/fuse_tir.cc | 12 ++++++---- src/relax/transform/inline_functions.cc | 2 +- src/relax/transform/normalize.cc | 4 ++-- src/target/llvm/codegen_llvm.cc | 14 ++++++++++- src/te/operation/create_primfunc.cc | 2 +- src/tir/ir/data_type_rewriter.cc | 23 +++++++++++-------- src/tir/ir/script/script_complete.cc | 2 +- src/tir/ir/stmt_functor.cc | 4 ++-- src/tir/schedule/concrete_schedule.cc | 4 ++-- .../schedule/primitive/blockize_tensorize.cc | 2 +- src/tir/schedule/primitive/cache_index.cc | 2 +- .../schedule/primitive/cache_read_write.cc | 22 +++++++++--------- src/tir/schedule/primitive/compute_inline.cc | 8 +++---- .../primitive/layout_transformation.cc | 16 ++++++------- .../schedule/primitive/loop_transformation.cc | 14 +++++------ src/tir/schedule/primitive/rolling_buffer.cc | 10 ++++---- src/tir/schedule/transform.cc | 8 +++---- src/tir/transforms/combine_context_call.cc | 2 +- src/tir/transforms/compact_buffer_region.cc | 12 +++++----- .../transforms/convert_blocks_to_opaque.cc | 2 +- src/tir/transforms/flatten_buffer.cc | 8 +++---- src/tir/transforms/inject_double_buffer.cc | 4 ++-- .../transforms/inject_software_pipeline.cc | 14 +++++------ src/tir/transforms/ir_utils.cc | 8 +++---- .../lower_cross_thread_reduction.cc | 7 +++--- src/tir/transforms/lower_custom_datatypes.cc | 7 +++--- .../transforms/lower_device_kernel_launch.cc | 4 ++-- .../lower_device_storage_access_info.cc | 2 +- src/tir/transforms/lower_match_buffer.cc | 2 +- src/tir/transforms/lower_opaque_block.cc | 3 ++- src/tir/transforms/lower_thread_allreduce.cc | 8 +++---- src/tir/transforms/lower_warp_memory.cc | 6 ++--- src/tir/transforms/make_unpacked_api.cc | 2 +- .../manifest_shared_memory_local_stage.cc | 4 ++-- .../transforms/memhammer_lower_auto_copy.cc | 8 +++---- .../merge_shared_memory_allocations.cc | 2 +- .../plan_update_buffer_allocation_location.cc | 4 ++-- .../reduce_branching_through_overcompute.cc | 7 +++--- src/tir/transforms/remove_no_op.cc | 5 ++-- .../remove_weight_layout_rewrite_block.cc | 2 +- src/tir/transforms/simplify.cc | 2 +- src/tir/transforms/split_host_device.cc | 3 ++- src/tir/transforms/storage_rewrite.cc | 9 ++++---- .../transforms/transform_mma_buffer_layout.cc | 12 ++++++---- src/tir/transforms/unify_thread_binding.cc | 3 ++- .../transforms/unsupported_dtype_legalize.cc | 4 ++-- .../update_pointer_storage_scope.cc | 2 +- src/tir/transforms/vectorize_loop.cc | 6 ++--- 59 files changed, 208 insertions(+), 172 deletions(-) diff --git a/apps/hexagon_launcher/launcher_util.cc b/apps/hexagon_launcher/launcher_util.cc index 9c565167142b..5524c2f0f338 100644 --- a/apps/hexagon_launcher/launcher_util.cc +++ b/apps/hexagon_launcher/launcher_util.cc @@ -47,7 +47,7 @@ std::string load_text_file(const std::string& file_name) { std::string buffer(file_size + 1, 0); in_file.read(&buffer[0], file_size); - return std::move(buffer); + return buffer; } void* load_binary_file(const std::string& file_name, void* buffer, size_t buffer_size) { diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 0f4c773e0d7e..7a02a3bedba8 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -752,7 +752,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { } else { ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), 1); } - return std::move(ret); + return ret; } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { @@ -776,7 +776,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { } else { ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), -1); } - return std::move(ret); + return ret; } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { @@ -798,11 +798,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { if (a.as()) { SumExpr ret = Downcast(std::move(a)); ret.CopyOnWrite()->MulToSelf(bconst->value); - return std::move(ret); + return ret; + } else { SplitExpr ret = ToSplitExpr(std::move(a)); ret.CopyOnWrite()->MulToSelf(bconst->value); - return std::move(ret); + return ret; } } @@ -969,7 +970,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { // can be divided by cval if (extra->IsZero()) { lhs.CopyOnWrite()->DivideBy(cval); - return std::move(lhs); + return lhs; } // both lhs and extra are non-negative if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && @@ -984,7 +985,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); } } - return std::move(lhs); + return lhs; } } else { // if a >= 0 && a < cval, then result == 0 @@ -1031,7 +1032,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { SeparateDivisibleParts(psum, cval, &lhs, &extra); if (extra->IsZero()) { lhs.CopyOnWrite()->DivideBy(cval); - return std::move(lhs); + return lhs; } // continue simplification. lhs.CopyOnWrite()->DivideBy(cval); @@ -1045,7 +1046,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); } } - return std::move(lhs); + return lhs; + } else { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); @@ -1371,14 +1373,14 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { SumExpr se = Downcast(value); if (se->CanPushCastToChildren(op->dtype, analyzer_)) { se.CopyOnWrite()->PushCastToChildren(op->dtype); - return std::move(se); + return se; } } if (value.as()) { SplitExpr se = Downcast(value); if (se->CanPushCastToChildren(op->dtype, analyzer_)) { se.CopyOnWrite()->PushCastToChildren(op->dtype); - return std::move(se); + return se; } } return Rewriter::VisitExpr_(op); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index f9ade53a3516..0a347040b76b 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -1001,7 +1001,7 @@ IntSet EvalSet(Range r, const Map& dom_map) { // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); - return std::move(res); + return res; } IntSet EvalSet(Range r, const std::unordered_map& dom_map) { diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 01aeba305027..52ed71edeac3 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1555,7 +1555,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { auto var = GetRef(op); auto it = var_map_.find(var); if (it != var_map_.end()) return it->second; - return std::move(var); + return var; } PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { @@ -1588,7 +1588,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { } else { AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), 1); } - return std::move(ret); + return ret; } PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { @@ -1623,7 +1623,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { } else { AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), -1); } - return std::move(ret); + return ret; } PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { @@ -1660,12 +1660,13 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance()) { IterSumExpr ret = Downcast(std::move(a)); MulToLhs(ret.CopyOnWrite(), b); - return std::move(ret); + return ret; + } else { ICHECK(a->IsInstance()); IterSplitExpr ret = Downcast(std::move(a)); ret.CopyOnWrite()->scale *= b; - return std::move(ret); + return ret; } } @@ -1854,7 +1855,8 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P if (is_one(rhs)) { if (is_zero(base)) { // floordiv(x, 1) = x - return std::move(lhs); + return lhs; + } else { // floordiv(x+y, 1) = x+y return IterSumExpr({lhs}, base); @@ -1865,7 +1867,8 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P if (CanProveDivisible(lhs->scale, rhs) && is_zero(base)) { // floordiv(x*c1*c2, c2) = x*c1, c1=scale/rhs lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); - return std::move(lhs); + return lhs; + } else if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { // floordiv(x*c1*c2 + y*c2, c2) = x*c1 + y, c1=scale/rhs lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); @@ -1929,7 +1932,8 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); if (is_zero(new_base)) { - return std::move(new_split); + return new_split; + } else { return IterSumExpr({new_split}, new_base); } diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc index a1a9768110ed..e998ba65f354 100644 --- a/src/arith/narrow_predicate_expression.cc +++ b/src/arith/narrow_predicate_expression.cc @@ -83,7 +83,8 @@ class ExpressionNarrower : public tir::ExprMutator { contains_unknown_expr_ = false; return Bool(CurrentContext() == Context::Minimize); } else if (a.same_as(t->a) && b.same_as(t->b)) { - return std::move(t); + return t; + } else { return T(a, b); } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ae6985ddf67e..5ff935776cae 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1720,7 +1720,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // supported path. TVM_TRY_REWRITE_IF(x == x, ctrue, SideEffect(x.Eval()) <= CallEffectKind::kReadState); } - return std::move(ret); + return ret; } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { @@ -1981,7 +1981,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { return RecursiveRewrite(floordiv(ret->a, common_factor) < floordiv(ret->b, common_factor)); } } - return std::move(ret); + return ret; } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { @@ -2009,7 +2009,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) { TVM_TRY_REWRITE(!(x != y), x == y); TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y)); TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y)); - return std::move(ret); + return ret; } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index a63671a0a2d0..382fa6284124 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -283,7 +283,7 @@ class DataflowBlockMutator : public ExprMutator { ICHECK(global_scope_vars.empty() && symbolic_vars.empty()) << "Error: DataflowBlock Pass should not delete any GlobalScope/Symbolic Var."; - return std::move(updated_block); + return updated_block; } private: diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 9ecf19e7ae11..5fa7feb90e42 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -528,7 +528,7 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { call.CopyOnWrite()->args = new_args; } - return std::move(call); + return call; } void ValidateCallTIR(Call call) { @@ -750,7 +750,7 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { } } - return std::move(call); + return call; } TVM_REGISTER_NODE_TYPE(CallTIRInplaceAttrs); diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 0c06cac75d19..a9d482139194 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -157,7 +157,7 @@ class LayoutConvertMutator : public ExprMutator { new_args.push_back(arg); } - return std::move(new_args); + return new_args; } void VisitBinding(const Binding& binding) final { diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 489b17f15c32..8f8cb0b18cb5 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -301,7 +301,7 @@ class ConstantFolder : public ExprMutator { } } - return std::move(post_call); + return post_call; } Expr VisitExpr_(const VarNode* op) final { diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 801dea14856d..e21c8a30a0e9 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1300,7 +1300,7 @@ class CompositeFunctionAnnotator : public ExprMutator { if (!func_node->GetAttr(attr::kComposite)) { // This lambda function doesn't have `attr::kComposite`, so it // was not produced by FuseOps. - return std::move(f_inner); + return f_inner; } f_inner = WithoutAttr(std::move(f_inner), tvm::relax::attr::kPrimitive); diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 90cb6a00fcfd..f3b9108fa689 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -202,7 +202,8 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); const Buffer& buffer = SubstituteBuffer(load->buffer); if (buffer.same_as(load->buffer)) { - return std::move(load); + return load; + } else { auto n = make_object(*load.get()); n->buffer = buffer; @@ -214,7 +215,8 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); const Buffer& buffer = SubstituteBuffer(store->buffer); if (buffer.same_as(store->buffer)) { - return std::move(store); + return store; + } else { auto n = make_object(*store.get()); n->buffer = buffer; @@ -271,7 +273,8 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { writes.same_as(block->writes) && // match_buffers.same_as(block->match_buffers) && alloc_buffers.same_as(block->alloc_buffers)) { - return std::move(block); + return block; + } else { auto n = CopyOnWrite(block.get()); n->reads = std::move(reads); @@ -342,7 +345,8 @@ class BlockNameDeduplicator : public tir::StmtMutator { String name = GetUniqueName(block->name_hint); if (name == block->name_hint) { - return std::move(block); + return block; + } else { ObjectPtr n = CopyOnWrite(block.get()); n->name_hint = std::move(name); diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index e295226e9e72..c0d69ee810f0 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -75,7 +75,7 @@ class FunctionInliner : public ExprMutator { } } - return std::move(node); + return node; } private: diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index d0c41ff77a37..d997ea040d60 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -188,7 +188,7 @@ class GlobalVarNormalizer : private ExprMutator { IRModule RenameModule() { if (!NeedRename()) { - return module_; + return std::move(module_); } // Step 1. Add public functions (functions with global_symbol attributes) @@ -212,7 +212,7 @@ class GlobalVarNormalizer : private ExprMutator { auto module_node = module_.CopyOnWrite(); module_node->functions = after_module->functions; module_node->global_var_map_ = after_module->global_var_map_; - return module_; + return std::move(module_); } /*! \brief Check if any function needs to be renamed. */ diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 45dafa85b939..200692df2b52 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -623,7 +623,7 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { return t_void_p_; } } else if (ptr->element_type->IsInstance()) { - return t_tvm_tensormap_->getPointerTo(); + return llvmGetPointerTo(t_tvm_tensormap_, 0); } // TODO(tvm-team) consider put storage scope into the pointer type. return llvmGetPointerTo(GetLLVMType(ptr->element_type), GetGlobalAddressSpace()); @@ -2247,9 +2247,15 @@ void CodeGenLLVM::AddDebugInformation(llvm::Function* f_llvm, const Array& auto* store = builder.CreateStore(iter_param, paramAlloca); auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, di_subprogram_); +#if TVM_LLVM_VERSION >= 190 + dbg_info_->di_builder_->insertDeclare( + paramAlloca, param, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc), + llvm::BasicBlock::iterator(store)); +#else dbg_info_->di_builder_->insertDeclare(paramAlloca, param, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc), store); +#endif } dbg_info_->di_builder_->finalizeSubprogram(f_llvm->getSubprogram()); auto* scope = f_llvm->getSubprogram(); @@ -2283,9 +2289,15 @@ void CodeGenLLVM::AddDebugInformation(llvm::Value* llvm_value, const Var& tir_va auto* di_loc = llvm::DILocation::get(*llvm_target_->GetContext(), 0, 0, di_subprogram_); if (insert_before) { +#if TVM_LLVM_VERSION >= 190 + dbg_info_->di_builder_->insertDeclare( + llvm_value, local_var, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc), + llvm::BasicBlock::iterator(insert_before)); +#else dbg_info_->di_builder_->insertDeclare(llvm_value, local_var, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc), insert_before); +#endif } else { dbg_info_->di_builder_->insertDeclare(llvm_value, local_var, dbg_info_->di_builder_->createExpression(), diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 27993f4a5bf6..0e90984e28b7 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -173,7 +173,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { n->annotations.erase(attr); } } - return std::move(block); + return block; } std::unordered_map buffer2index_; diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index c0a9400e422f..11b29016bf72 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -61,7 +61,7 @@ Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { if (changed) { realize.CopyOnWrite()->iter_values = std::move(new_iter_values); } - return std::move(realize); + return realize; } Stmt DataTypeLegalizer::VisitStmt_(const BlockNode* op) { @@ -80,7 +80,7 @@ Stmt DataTypeLegalizer::VisitStmt_(const BlockNode* op) { if (!op->iter_vars.same_as(new_iter_vars)) { new_block.CopyOnWrite()->iter_vars = std::move(new_iter_vars); } - return std::move(new_block); + return new_block; } Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) { @@ -269,7 +269,8 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) { n->extents = std::move(new_extents); n->condition = std::move(new_cond); n->body = std::move(new_body); - return std::move(new_allocate); + return new_allocate; + } else { return GetRef(op); } @@ -292,7 +293,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) { if (!new_buffer.same_as(op->buffer)) { decl_buffer.CopyOnWrite()->buffer = new_buffer; } - return std::move(decl_buffer); + return decl_buffer; } Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { @@ -314,7 +315,8 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { n->predicate = std::move(new_predicate); n->iter_values = std::move(new_iter_values); n->block = std::move(new_body); - return std::move(new_block_realize); + return new_block_realize; + } else { return GetRef(op); } @@ -362,7 +364,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { n->init = std::move(new_init); n->annotations = std::move(new_annotations); n->body = std::move(new_body); - return std::move(new_block); + return new_block; } return GetRef(op); } @@ -483,7 +485,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) { writer->indices = indices; } - return std::move(store); + return store; } PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) { @@ -498,7 +500,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) { writer->buffer = new_buffer; } - return std::move(load); + return load; } Array IndexDataTypeRewriter::VisitIndices(Array indices) { @@ -529,7 +531,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) { n->condition = std::move(cond); n->then_case = std::move(then_case); n->else_case = std::move(else_case); - return std::move(new_stmt); + return new_stmt; } return GetRef(op); } @@ -558,7 +560,8 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { n->thread_binding = std::move(Optional(std::move(old_thread_binding))); } n->body = new_body; - return std::move(new_for); + return new_for; + } else { return GetRef(op); } diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index 7e8c2913e55f..00313bfd0227 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -95,7 +95,7 @@ class ScriptCompleter : public StmtMutator { n->annotations.erase(attr::script_parsing_detect_access); return Block(n); } else { - return std::move(block); + return block; } } diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 4daa0e9a5468..7ecb4558506c 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -602,7 +602,7 @@ class IRSubstitute : public StmtExprMutator { } return ret.value(); } - return std::move(var); + return var; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -754,7 +754,7 @@ class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { if (ret.defined()) { return ret.value(); } - return std::move(var); + return var; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index edaccb51d687..0b8aeec82c1f 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -947,7 +947,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { for (size_t i = 0; i < arr->size(); i++) { result.push_back(CheckAndGetAnnotationValue(arr->at(i))); } - return std::move(result); + return result; } if (const auto* dict = ann_val.as()) { Map result; @@ -962,7 +962,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { LOG(FATAL) << "TypeError: annotation dict key expect to be String or StringImm"; } } - return std::move(result); + return result; } LOG(FATAL) << "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but " diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index a5ec9d436b17..4828701bb571 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -401,7 +401,7 @@ Stmt Substitute(const Stmt& stmt, const Map& sub, if (!src.same_as(tgt)) { block_sref_reuse_->Set(src, tgt); } - return std::move(tgt); + return tgt; } const Map& sub_; diff --git a/src/tir/schedule/primitive/cache_index.cc b/src/tir/schedule/primitive/cache_index.cc index 2e94b2050496..9ea47def4c31 100644 --- a/src/tir/schedule/primitive/cache_index.cc +++ b/src/tir/schedule/primitive/cache_index.cc @@ -403,7 +403,7 @@ class CacheIndexRewriter : public StmtExprMutator { stmt = Block(n); } info_->block_reuse.Set(old_stmt, stmt); - return std::move(stmt); + return stmt; } Stmt VisitStmt_(const BufferStoreNode* store) final { diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 1b2a3a1cb478..7f4415f85c2b 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -934,7 +934,7 @@ class CacheReadRewriter : public StmtExprMutator { // We don't mutate the block which generates info->read_buffer. if (block != scope_sref_->stmt && GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) { - return std::move(old_stmt); + return old_stmt; } // Mutate the body Block stmt = Downcast(StmtMutator::VisitStmt_(block)); @@ -970,7 +970,7 @@ class CacheReadRewriter : public StmtExprMutator { } } info_->block_reuse.Set(old_stmt, stmt); - return std::move(stmt); + return stmt; } Array RewriteIndices(const Array& indices) { @@ -1194,15 +1194,15 @@ class CacheWriteRewriter : public StmtExprMutator { n->body = VisitStmt(block->body); Block new_consumer = Block(n); info_->block_reuse.Set(old_stmt, new_consumer); - return std::move(new_consumer); + return new_consumer; } - return std::move(old_stmt); + return old_stmt; } } // We only mutate the block which generates info->write_buffer if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) { - return std::move(old_stmt); + return old_stmt; } // Mutate the body @@ -1240,7 +1240,7 @@ class CacheWriteRewriter : public StmtExprMutator { } } info_->block_reuse.Set(old_stmt, stmt); - return std::move(stmt); + return stmt; } Array RewriteIndices(const Array& indices) { @@ -1261,7 +1261,7 @@ class CacheWriteRewriter : public StmtExprMutator { } return Stmt(n); } else { - return std::move(stmt); + return stmt; } } @@ -1371,7 +1371,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { n->indices = new_indices_; return Stmt(n); } else { - return std::move(stmt); + return stmt; } } @@ -1565,7 +1565,7 @@ class ReIndexRewriter : public StmtExprMutator { n->alloc_buffers.push_back(info_->alloc.value()); stmt = Block(n); info_->block_reuse.Set(old_stmt, stmt); - return std::move(stmt); + return stmt; } // Visiting the blokc being reindexed @@ -1594,9 +1594,9 @@ class ReIndexRewriter : public StmtExprMutator { stmt = Block(n); } info_->block_reuse.Set(old_stmt, stmt); - return std::move(stmt); + return stmt; } - return std::move(old_stmt); + return old_stmt; } template diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 85f3a0f82f76..5a6cddd3ccdc 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -324,7 +324,7 @@ class BaseInliner : public StmtExprMutator { bool is_scope_root = src_block.get() == scope_root_sref_->stmt; tgt_block = UpdateBuffersInBlockSignature(std::move(tgt_block), is_scope_root); block_reuse.Set(src_block, tgt_block); - return std::move(tgt_block); + return tgt_block; } private: @@ -527,7 +527,7 @@ class ComputeInliner : public BaseInliner { PrimExpr VisitExpr_(const BufferLoadNode* _load) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); if (!load->buffer.same_as(inlined_buffer_)) { - return std::move(load); + return load; } return ReplaceInlinedBuffer(std::move(load)); } @@ -758,13 +758,13 @@ class ReverseComputeInliner : public BaseInliner { tgt_block_realize = BuildInlinedConsumerPredicate(tgt_block_realize); block_reuse.Set(src_block, tgt_block_realize->block); } - return std::move(tgt_block_realize); + return tgt_block_realize; } Stmt VisitStmt_(const BufferStoreNode* _store) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); if (!store->buffer.same_as(inlined_buffer_)) { - return std::move(store); + return store; } return ReplaceInlinedBuffer(std::move(store)); } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index a455afe6b067..5cce3ea7758f 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -383,7 +383,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } } - return std::move(realize); + return realize; } Stmt VisitStmt_(const BlockNode* op) final { @@ -391,7 +391,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { Block mutated = Downcast(StmtExprMutator::VisitStmt_(op)); RecordReplacement(orig, mutated); - return std::move(mutated); + return mutated; } PrimExpr VisitExpr_(const VarNode* op) final { @@ -399,7 +399,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { if (auto opt = var_remap.Get(var)) { return opt.value(); } else { - return std::move(var); + return var; } } @@ -841,7 +841,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { auto* n = buffer_load.CopyOnWrite(); RewriteBufferAccess(&n->buffer, &n->indices); } - return std::move(buffer_load); + return buffer_load; } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -850,7 +850,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { auto* n = buffer_store.CopyOnWrite(); RewriteBufferAccess(&n->buffer, &n->indices); } - return std::move(buffer_store); + return buffer_store; } void RewriteAccessRegion(Array* old_access_regions, @@ -893,7 +893,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { }); RecordReplacement(orig, block); - return std::move(block); + return block; } void RecordReplacement(Block before, Block after) { @@ -1598,7 +1598,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits attrs_record.push_back(attrs[2]); } attrs_record.push_back(attrs[3]); - return std::move(attrs_record); + return attrs_record; } static Array AttrsFromJSON(const ObjectRef& attrs_record_) { @@ -1644,7 +1644,7 @@ struct TransformBlockLayoutTraits : public UnpackedInstTraits attrs_record; attrs_record.reserve(kNumAttrs); attrs_record.push_back(String(::tvm::SaveJSON(attrs[0]))); - return std::move(attrs_record); + return attrs_record; } static Array AttrsFromJSON(const ObjectRef& attrs_record_) { diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index d112560a1fee..88e6f61eb333 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -56,7 +56,7 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { if (Optional ret = vmap_(var)) { return tvm::cast(var.dtype(), ret.value()); } else { - return std::move(var); + return var; } } @@ -65,7 +65,7 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { if (realize->block->iter_vars.empty()) { opaque_blocks_->Set(op->block, realize->block); } - return std::move(realize); + return realize; } /*! \brief The substitute function */ @@ -113,7 +113,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { break; } } - return std::move(realize); + return realize; } Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, @@ -537,7 +537,7 @@ class BlockMutator : public StmtExprMutator { if (new_block->iter_vars.size() == 0 || inner_iter_var_index == -1) { new_block.CopyOnWrite()->name_hint = new_block.CopyOnWrite()->name_hint + "_" + new_loop_var_->name_hint; - return std::move(new_block); + return new_block; } Var iter_var_ = new_block->iter_vars[inner_iter_var_index]->var; @@ -594,7 +594,7 @@ class BlockMutator : public StmtExprMutator { // Update all instances of old iter_vars in the block with new iter_vars auto block_stmt = tir::Substitute(new_block, var_map); - return std::move(block_stmt); + return block_stmt; } Stmt VisitStmt_(const BlockRealizeNode* realize) final { @@ -607,7 +607,7 @@ class BlockMutator : public StmtExprMutator { } } BlockRealize stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); - return std::move(stmt); + return stmt; } Stmt VisitStmt_(const ForNode* op) final { @@ -776,7 +776,7 @@ class LoopReconstructor : private StmtMutator { } else if (ret->size() == 1) { return ret->seq[0]; } else { - return std::move(ret); + return ret; } } diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc index 19375f7235fc..bef5faf92b67 100644 --- a/src/tir/schedule/primitive/rolling_buffer.cc +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -215,7 +215,7 @@ class RollingBufferInfoCollector { // Pick the outermost iter_var that's mentioned in the bounds // to be the rolling axis Optional roll_iter_var; - int roll_axis; + int roll_axis = 0; for (const tir::StmtSRef& loop_sref : loop_srefs) { auto loop_var = loop_sref->StmtAs()->loop_var; @@ -331,7 +331,7 @@ class RollingBufferRewriter : public StmtExprMutator { RewriteAccessRegion(&n->writes, infered_access_regions[1]); } info_->block_reuse.Set(old_stmt, stmt); - return std::move(stmt); + return stmt; } Stmt VisitStmt_(const BlockRealizeNode* realize) final { @@ -355,7 +355,7 @@ class RollingBufferRewriter : public StmtExprMutator { BlockRealizeNode* n = stmt.CopyOnWrite(); n->predicate = condition; } - return std::move(stmt); + return stmt; } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -366,7 +366,7 @@ class RollingBufferRewriter : public StmtExprMutator { // Need to add predicate to the current block to avoid recomputing elements. rewrite_block_predicate_ = true; } - return std::move(stmt); + return stmt; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -375,7 +375,7 @@ class RollingBufferRewriter : public StmtExprMutator { BufferLoadNode* n = stmt.CopyOnWrite(); RewriteBufferAccess(&n->buffer, &n->indices); } - return std::move(stmt); + return stmt; } private: diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index c0929e01a8ad..fc284fec20e6 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -226,7 +226,7 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { if (block_sref_reuse_ != nullptr) { block_sref_reuse_->Set(GetRef(block), new_block); } - return std::move(new_block); + return new_block; } } @@ -470,19 +470,19 @@ Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) { auto* n = block.CopyOnWrite(); SimplifyAccessRegion(&n->reads); SimplifyAccessRegion(&n->writes); - return std::move(block); + return block; } Stmt BlockBufferAccessSimplifier::VisitStmt_(const BufferStoreNode* op) { BufferStore node = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); SimplifyBufferIndices(&node.CopyOnWrite()->indices); - return std::move(node); + return node; } PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) { BufferLoad node = Downcast(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); SimplifyBufferIndices(&node.CopyOnWrite()->indices); - return std::move(node); + return node; } /******** PrimFunc-level analysis and transformation ********/ diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 2ff7c03c6287..38fe86a9e2ac 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -53,7 +53,7 @@ class ContextCallCombiner final : public StmtExprMutator { ICHECK(ctx.dtype().is_handle()); Var ctx_var("ctx_cache_", ctx.dtype()); ctx_map_[ctx] = ctx_var; - return std::move(ctx_var); + return ctx_var; } } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index c5c6accf221a..b75d73bc0f5f 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -552,14 +552,14 @@ class BufferCompactor : public StmtExprMutator { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); BufferStoreNode* op = store.CopyOnWrite(); RewriteBufferAccess(&op->buffer, &op->indices); - return std::move(store); + return store; } PrimExpr VisitExpr_(const BufferLoadNode* _op) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); BufferLoadNode* op = load.CopyOnWrite(); RewriteBufferAccess(&op->buffer, &op->indices); - return std::move(load); + return load; } Stmt VisitStmt_(const BlockNode* op) final { @@ -576,7 +576,7 @@ class BufferCompactor : public StmtExprMutator { RewriteBufferRegions(&n->writes); RewriteMatchBuffers(&n->match_buffers); n->alloc_buffers = std::move(alloc_buffers); - return std::move(block); + return block; } Stmt VisitStmt_(const DeclBufferNode* op) final { @@ -591,20 +591,20 @@ class BufferCompactor : public StmtExprMutator { Allocate allocate = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_info_.find(allocate->buffer_var); if (it == buffer_info_.end()) { - return std::move(allocate); + return allocate; } // Rewrite allocation shape if the corresponding buffer is in the buffer_info_ // dict and the dtype is consistent, which denotes there are no buffer aliasing // and the compaction is safe. const Buffer& new_buffer = it->second.new_buffer; if (op->dtype != new_buffer->dtype) { - return std::move(allocate); + return allocate; } Array new_shape = GetBufferAllocationShape(new_buffer); auto n = allocate.CopyOnWrite(); ICHECK(n->buffer_var.same_as(new_buffer->data)); n->extents = new_shape; - return std::move(allocate); + return allocate; } Buffer RewriteAllocBuffer(const Buffer& buffer) { diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index 1b29cea2f27a..09c2762efab5 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -63,7 +63,7 @@ class OpaqueBlockConverter : public StmtExprMutator { if (!new_block->iter_vars.empty()) { new_block.CopyOnWrite()->iter_vars.clear(); } - return std::move(new_block); + return new_block; } Stmt VisitStmt_(const BlockRealizeNode* realize) final { diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 5ea0a60ea2a8..a0c39c8fcb68 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -149,7 +149,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { alloc.CopyOnWrite()->extents = new_extents; } - return std::move(alloc); + return alloc; } Stmt VisitStmt_(const DeclBufferNode* op) final { @@ -196,9 +196,9 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { << "Expected int8 backing array for boolean tensor"; auto writer = store.CopyOnWrite(); writer->value = tvm::cast(DataType::Int(8), store->value); - return std::move(store); + return store; } - return std::move(store); + return store; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -214,7 +214,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { load.CopyOnWrite()->dtype = DataType::Int(8); return tvm::cast(DataType::Bool(), load); } else { - return std::move(load); + return load; } } diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 3c3ddf4b4b26..8c4e526e5175 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -202,7 +202,7 @@ class DoubleBufferInjector : public StmtExprMutator { writer->indices = {e.switch_write_var * e.stride + node->indices[0]}; } - return std::move(node); + return node; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -221,7 +221,7 @@ class DoubleBufferInjector : public StmtExprMutator { writer->indices = {e.switch_read_var * e.stride + node->indices[0]}; } - return std::move(node); + return node; } Buffer GetRemappedBuffer(Buffer buf, PrimExpr stride) { diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 4f137619ea7e..d5f69315b149 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -260,14 +260,14 @@ class PipelineBodyRewriter : public StmtExprMutator { for (const Buffer& alloc_buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(alloc_buffer->data); } - return std::move(block); + return block; } Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_remap_.find(store->buffer); if (it == buffer_remap_.end()) { - return std::move(store); + return store; } const Buffer& new_buffer = (*it).second; auto* n = store.CopyOnWrite(); @@ -275,14 +275,14 @@ class PipelineBodyRewriter : public StmtExprMutator { PrimExpr version = floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); n->indices.insert(n->indices.begin(), version); - return std::move(store); + return store; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); auto it = buffer_remap_.find(load->buffer); if (it == buffer_remap_.end()) { - return std::move(load); + return load; } const Buffer& new_buffer = (*it).second; auto* n = load.CopyOnWrite(); @@ -290,7 +290,7 @@ class PipelineBodyRewriter : public StmtExprMutator { PrimExpr version = floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); n->indices.insert(n->indices.begin(), version); - return std::move(load); + return load; } PrimExpr VisitExpr_(const CallNode* op) final { @@ -1074,7 +1074,7 @@ class PipelineInjector : private StmtExprMutator { // Step 1: Recursively rewrite the children first. For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); if (!HasPipelineAnnotation(op)) { - return std::move(for_node); + return for_node; } // Step 2: Find the body and buffer allocations of the pipeline. The body can be direct child of // the for-loop. If the for-loop has BlockRealize as its child, the pipeline body will be the @@ -1215,7 +1215,7 @@ class PipelineInjector : private StmtExprMutator { for (const auto& buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); } - return std::move(block); + return block; } bool HasPipelineAnnotation(const ForNode* op) const { diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 0017e97beb88..72ee656d1bac 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -228,13 +228,13 @@ class IRConvertSSA final : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); auto output = VisitBufferAccess(std::move(node)); - return std::move(output); + return output; } Stmt VisitStmt_(const BufferStoreNode* op) final { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); auto output = VisitBufferAccess(std::move(node)); - return std::move(output); + return output; } Stmt VisitStmt_(const DeclBufferNode* op) final { @@ -243,7 +243,7 @@ class IRConvertSSA final : public StmtExprMutator { if (!new_buffer.same_as(decl->buffer)) { decl.CopyOnWrite()->buffer = std::move(new_buffer); } - return std::move(decl); + return decl; } Stmt VisitStmt_(const BlockNode* op) final { @@ -669,7 +669,7 @@ Optional ConditionalBoundsContext::TrySolveCondition() { if (!result->relations.empty()) { return std::nullopt; } - return std::move(result); + return result; } ConditionalBoundsContext::ConditionalBoundsContext( diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 31d7f91d74e9..a10937a2b7c9 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -255,9 +255,10 @@ class InThreadReducerMaker : private StmtMutator { if (!res->body.defined() || collector.CheckHasReductionBlocks(res)) { return res->body; } - return std::move(res); + return res; + } else { - return std::move(res); + return res; } } else { return Stmt{nullptr}; @@ -788,7 +789,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } } } - return std::move(new_block); + return new_block; } void MakeCrossThreadReduction(const BlockRealizeNode* realize, diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index dbc529cfeabd..e0863b865d15 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -80,7 +80,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { if (itr != var_remap_.end()) { return itr->second; } else { - return std::move(var); + return var; } } @@ -115,11 +115,12 @@ class CustomDatatypesLowerer : public StmtExprMutator { // Not needed for BufferStoreNode, so we can't just call // LegalizeDtype() in VisitBufferAccess. if (node.same_as(modified)) { - return std::move(node); + return node; + } else { auto writer = modified.CopyOnWrite(); writer->LegalizeDType(); - return std::move(modified); + return modified; } } diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 2ca0e6d92f68..c32c0c3debf3 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -240,7 +240,7 @@ class DeviceKernelMutator : public StmtExprMutator { auto node = Downcast(Parent::VisitExpr_(op)); auto* gvar = op->op.as(); - if (!gvar) return std::move(node); + if (!gvar) return node; auto it = device_info_map_.find(gvar); ICHECK(it != device_info_map_.end()) @@ -255,7 +255,7 @@ class DeviceKernelMutator : public StmtExprMutator { if (same_target) { // Calls within the same target may be handled at codegen time // as internal subroutine calls. - return std::move(node); + return node; } bool same_device_type = diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index a30232b9ce80..fe5e0389676d 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -68,7 +68,7 @@ class StorageAccessInfoLower : public StmtExprMutator { it != storage_info_.end() && !it->second->head_address.defined()) { return node->body; } else { - return std::move(node); + return node; } } diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 6e2ea5bc14af..e0cb7cf80acc 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -78,7 +78,7 @@ class MatchBufferLower : public StmtExprMutator { if (it != var_map_.end()) { return (*it).second; } else { - return std::move(v); + return v; } } diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index f3551987426d..64de12263c3e 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -120,7 +120,8 @@ class OpaqueBlockLower : public StmtExprMutator { Var var = GetRef(op); auto it = unit_loop_vars_.find(var); if (it == unit_loop_vars_.end()) { - return std::move(var); + return var; + } else { PrimExpr expr = it->second; if (expr.dtype() != var.dtype()) { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 0d2092338228..81023d5471f3 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -88,7 +88,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1, write_ptr->body); } } - return std::move(node); + return node; } Optional GetRemappedBuffer(const Buffer& buf) { @@ -111,7 +111,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (auto buf = GetRemappedBuffer(node->buffer)) { node.CopyOnWrite()->buffer = buf.value(); } - return std::move(node); + return node; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -128,7 +128,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (auto opt = GetRemappedBuffer(load->buffer)) { load.CopyOnWrite()->buffer = opt.value(); } - return std::move(load); + return load; } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -137,7 +137,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (auto opt = GetRemappedBuffer(store->buffer)) { store.CopyOnWrite()->buffer = opt.value(); } - return std::move(store); + return store; } private: diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index b1642bef3c92..0cf6f9d152d1 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -302,14 +302,14 @@ class WarpAccessRewriter : protected StmtExprMutator { writer->indices = {local_index}; } - return std::move(store); + return store; } PrimExpr VisitExpr_(const BufferLoadNode* op) override { auto load = Downcast(StmtExprMutator::VisitExpr_(op)); if (load->buffer->data.get() != buffer_) { - return std::move(load); + return load; } ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " @@ -325,7 +325,7 @@ class WarpAccessRewriter : protected StmtExprMutator { writer->indices = {local_index}; if (analyzer_->CanProveEqual(group, warp_index_)) { - return std::move(load); + return load; } PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index a72d68972735..989b39f0e370 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -83,7 +83,7 @@ class SubroutineCallRewriter : public StmtExprMutator { } } - return std::move(node); + return node; } const std::unordered_set& external_methods_; bool made_change_{false}; diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index dc9420f728be..7c9fe2b9aacd 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -198,7 +198,7 @@ class SharedMemoryLocalStageInserter : public StmtMutator { buffer_local_stage_.Set(target_buffer, local_stage); target_buffers_.push_back(target_buffer); - return std::move(new_block); + return new_block; } std::unordered_set allocated_buffers( @@ -255,7 +255,7 @@ class SharedMemoryLocalStageInserter : public StmtMutator { new_block_node->alloc_buffers = Concat(new_block_node->alloc_buffers, new_alloc_buffers); } new_block_node->body = new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq); - return std::move(new_block); + return new_block; } std::vector ancestor_loop_or_blocks_; // ancestor loops or block realize diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc index 916c5c84e9af..334a44df069c 100644 --- a/src/tir/transforms/memhammer_lower_auto_copy.cc +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -198,7 +198,7 @@ class AutoPadder { if (buffer_map_.count(op->buffer)) { op->buffer = buffer_map_[op->buffer]; } - return std::move(load); + return load; } Stmt VisitStmt_(const BufferStoreNode* _op) final { @@ -207,7 +207,7 @@ class AutoPadder { if (buffer_map_.count(op->buffer)) { op->buffer = buffer_map_[op->buffer]; } - return std::move(store); + return store; } Stmt VisitStmt_(const BlockNode* op) final { @@ -665,7 +665,7 @@ class AutoCopyMutator : public StmtExprMutator { if (!GetAnn(op, tir::attr::auto_copy).value_or(false)) { BlockNode* n = block.CopyOnWrite(); n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); - return std::move(block); + return block; } ICHECK_EQ(block->writes.size(), 1); ICHECK_GE(block->reads.size(), 1); @@ -703,7 +703,7 @@ class AutoCopyMutator : public StmtExprMutator { } padder.AnalyzeSharedMemoryAccess(block->body, outer_loops_, data_bits, thread_extent_); n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); - return std::move(block); + return block; } Stmt VisitStmt_(const ForNode* op) final { diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 52966e005aaa..eaf0aab391ec 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -331,7 +331,7 @@ class SharedMemoryRewriter : public StmtExprMutator { if (auto new_buf = GetUpdatedBuffer(node->buffer); !new_buf.same_as(node->buffer)) { node.CopyOnWrite()->buffer = new_buf; } - return std::move(node); + return node; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index c141ef33c289..f4547a57581a 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -156,7 +156,7 @@ class BufferAllocationLocator : public StmtExprMutator { node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs); } - return std::move(node); + return node; } Stmt VisitStmt_(const BlockNode* op) final { @@ -220,7 +220,7 @@ class BufferAllocationLocator : public StmtExprMutator { n->reads = access[0]; n->writes = access[1]; BlockRealize realize({}, Bool(true), Block(n)); - return std::move(realize); + return realize; } Array RemoveRedundantBufferRegion(const Array& region) const { diff --git a/src/tir/transforms/reduce_branching_through_overcompute.cc b/src/tir/transforms/reduce_branching_through_overcompute.cc index 45fc523a7f9f..5015d2418a47 100644 --- a/src/tir/transforms/reduce_branching_through_overcompute.cc +++ b/src/tir/transforms/reduce_branching_through_overcompute.cc @@ -71,7 +71,8 @@ struct ElseBranchFiller : StmtExprMutator { Stmt VisitStmt_(const IfThenElseNode* op) override { IfThenElse ret = Downcast(StmtExprMutator::VisitStmt_(op)); if (ret->else_case.defined()) { - return std::move(ret); + return ret; + } else { auto new_else_clause = Evaluate(0); new_else_clauses.insert(new_else_clause); @@ -95,7 +96,7 @@ class ElseBranchStripper : public StmtExprMutator { as_eval && new_else_clauses_.count(as_eval.value())) { return IfThenElse(ret->condition, ret->then_case); } else { - return std::move(ret); + return ret; } } @@ -137,7 +138,7 @@ class BranchReducer : public arith::IRMutatorWithAnalyzer { } else if (is_special_case(!cond->condition, cond->then_case, else_case)) { return cond->then_case; } else { - return std::move(cond); + return cond; } } diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index f752901619eb..a67b2bf17878 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -239,7 +239,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { } } - return std::move(store); + return store; } Stmt VisitStmt_(const DeclBufferNode* op) final { @@ -249,7 +249,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { var_use(node->body); if (var_use.buffer_use_count_.count(node->buffer.get())) { - return std::move(node); + return node; + } else { return node->body; } diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 881f321bf673..0ca4262fc119 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -64,7 +64,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator { n->alloc_buffers = std::move(alloc_buffers); return Stmt(n); } else { - return std::move(block); + return block; } } diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index f6b367337381..1d8ff9dab0c8 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -299,7 +299,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Evaluate(0); } } - return std::move(store); + return store; } private: diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 0b12bd02d482..13bf0186895a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -110,7 +110,8 @@ class HostDeviceSplitter : public StmtMutator { StringImm("Error executing compute kernel"), Evaluate(0)); LetStmt let_check(kernel_error_code, kernel_call, assert_success); - return std::move(let_check); + return let_check; + } else { return Evaluate(Call(DataType::Void(), kernel_symbol_global, args)); } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index b8062e2a2f10..bd3afe1d3f84 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -528,7 +528,7 @@ class StoragePlanRewriter : public StmtExprMutator { Buffer buf = RemapBuffer(op->buffer, it->second->alloc_var); node.CopyOnWrite()->buffer = buf; } - return std::move(node); + return node; } private: @@ -1510,14 +1510,15 @@ class VectorTypeRewriter : public StmtExprMutator { // Not needed for BufferStoreNode, so we can't just call // LegalizeDtype() in VisitBufferAccess. if (node.same_as(modified)) { - return std::move(node); + return node; + } else { auto writer = modified.CopyOnWrite(); writer->LegalizeDType(); if (shuffle_index >= 0) { return Shuffle::ExtractElement(std::move(modified), shuffle_index); } - return std::move(modified); + return modified; } } @@ -1525,7 +1526,7 @@ class VectorTypeRewriter : public StmtExprMutator { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); auto [modified, shuffle_index] = VisitBufferAccess(std::move(node)); ICHECK(shuffle_index < 0); - return std::move(modified); + return modified; } Stmt VisitStmt_(const LetStmtNode* op) final { diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc index 3ef35d74cf8a..797caae31100 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/tir/transforms/transform_mma_buffer_layout.cc @@ -71,7 +71,8 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { buffer->axis_separators); this->buffer_map_.insert({buffer, new_buffer}); this->buffer_var_map_.insert({buffer->data, new_buffer->data}); - return std::move(new_buffer); + return new_buffer; + } else if (buffer.scope() == "m16n8k8.matrixA") { // m16n8k8.matrixA // bi = 32, bj = 8 @@ -92,7 +93,8 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { buffer->axis_separators); this->buffer_map_.insert({buffer, new_buffer}); this->buffer_var_map_.insert({buffer->data, new_buffer->data}); - return std::move(new_buffer); + return new_buffer; + } else if (buffer.scope() == "m16n8k8.matrixB") { // m16n8k8.matrixB // bj = 8, bj = 32 @@ -113,7 +115,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { buffer->axis_separators); this->buffer_map_.insert({buffer, new_buffer}); this->buffer_var_map_.insert({buffer->data, new_buffer->data}); - return std::move(new_buffer); + return new_buffer; } return buffer; }; @@ -138,7 +140,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { n->buffer = buffer_map_[store->buffer]; } } - return std::move(store); + return store; } PrimExpr VisitExpr_(const BufferLoadNode* op) { @@ -157,7 +159,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { n->buffer = buffer_map_[load->buffer]; } } - return std::move(load); + return load; } PrimExpr VisitExpr_(const VarNode* op) { diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index 08fc921f4ebf..c83f00d25b82 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -68,7 +68,8 @@ class ThreadBindingUnifier : public StmtExprMutator { if (const auto* loop = stmt.as()) { For new_loop = GetRef(loop); new_loop.CopyOnWrite()->annotations = std::move(annotations); - return std::move(new_loop); + return new_loop; + } else { // Create a new unit loop with the annotation. DataType dtype = op->loop_var->dtype; diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index c4d2d4608044..8ee1656b3fe4 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -257,7 +257,7 @@ class ComputeLegalizer : public StmtExprMutator { if (itr != var_remap_.end()) { return itr->second; } else { - return std::move(var); + return var; } } @@ -530,7 +530,7 @@ class StorageLegalizer : public StmtExprMutator { if (itr != var_remap_.end()) { return itr->second; } else { - return std::move(var); + return var; } } diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 2049487b4a78..9af990d1e2bf 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -64,7 +64,7 @@ Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { if (auto it = new_var_remap_.find(node->buffer_var.get()); it != new_var_remap_.end()) { node.CopyOnWrite()->buffer_var = it->second; } - return std::move(node); + return node; } template diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index c23ce2828ce5..cfe0145d9278 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -464,7 +464,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorsecond; } else { - return std::move(var); + return var; } } // IfThenElse expr @@ -597,7 +597,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorLegalizeDType(); } - return std::move(load); + return load; } // Let PrimExpr VisitExpr_(const LetNode* op) final { @@ -738,7 +738,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvalue = BroadcastTo(value, total_lanes, is_last_index_scalable); } - return std::move(store); + return store; } // For Stmt VisitStmt_(const ForNode* op) final { From e6dfee5e9e7758b0c0aead8fdb1707be543afd52 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 2 Jul 2025 14:19:03 -0400 Subject: [PATCH 3/3] fix dataflow block --- include/tvm/relax/expr.h | 5 +++++ src/relax/ir/expr.cc | 1 + 2 files changed, 6 insertions(+) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 808fbed3cfc7..df512186696e 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -816,6 +816,11 @@ class BindingBlock : public ObjectRef { class DataflowBlockNode : public BindingBlockNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { return equal(bindings, other->bindings); } diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 4db18817e154..da4f3cb22ec2 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -45,6 +45,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ MatchCastNode::RegisterReflection(); VarBindingNode::RegisterReflection(); BindingBlockNode::RegisterReflection(); + DataflowBlockNode::RegisterReflection(); SeqExprNode::RegisterReflection(); IfNode::RegisterReflection(); FunctionNode::RegisterReflection();