Skip to content

Commit d225438

Browse files
committed
[REFACTOR] Phase out the RelaxExpr.checked_type in favor of struct_info
Previously checked_type was used to keep track of the type in relay expression. Now that we have struct_info which strictly contains super set of information and relay is phased out. We can simplify the code logic by removing the checked_type.
1 parent 2d63574 commit d225438

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+90
-350
lines changed

include/tvm/ir/expr.h

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -400,38 +400,13 @@ TVM_DLL PrimExpr operator~(PrimExpr a);
400400
*/
401401
class RelaxExprNode : public BaseExprNode {
402402
public:
403-
/*!
404-
* \brief Stores the result of type inference(type checking).
405-
*
406-
* \note This can be undefined before type inference.
407-
* This value is discarded during serialization.
408-
*/
409-
mutable Type checked_type_ = Type(nullptr);
410-
411403
/*!
412404
* \brief Stores the result of structure information of the
413405
* expression that encapsulate both static shape and
414406
* runtime information such as shape.
415407
*/
416408
mutable Optional<ObjectRef> struct_info_ = Optional<ObjectRef>();
417409

418-
/*!
419-
* \return The checked_type
420-
*/
421-
inline const Type& checked_type() const;
422-
/*!
423-
* \brief Check if the inferred(checked) type of the Expr
424-
* is backed by a TTypeNode and return it.
425-
*
426-
* \note This function will thrown an error if the node type
427-
* of this Expr is not TTypeNode.
428-
*
429-
* \return The corresponding TTypeNode pointer.
430-
* \tparam The specific TypeNode we look for.
431-
*/
432-
template <typename TTypeNode>
433-
inline const TTypeNode* type_as() const;
434-
435410
static constexpr const char* _type_key = "RelaxExpr";
436411
static constexpr const uint32_t _type_child_slots = 22;
437412
TVM_DECLARE_BASE_OBJECT_INFO(RelaxExprNode, BaseExprNode);
@@ -463,7 +438,6 @@ class GlobalVarNode : public RelaxExprNode {
463438
void VisitAttrs(AttrVisitor* v) {
464439
v->Visit("name_hint", &name_hint);
465440
v->Visit("span", &span);
466-
v->Visit("_checked_type_", &checked_type_);
467441
v->Visit("struct_info_", &struct_info_);
468442
}
469443

@@ -487,7 +461,7 @@ class GlobalVarNode : public RelaxExprNode {
487461
*/
488462
class GlobalVar : public RelaxExpr {
489463
public:
490-
TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {});
464+
TVM_DLL explicit GlobalVar(String name_hint, Span span = {});
491465

492466
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelaxExpr, GlobalVarNode);
493467
TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
@@ -747,26 +721,6 @@ class Range : public ObjectRef {
747721
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
748722
};
749723

750-
// implementations
751-
inline const Type& RelaxExprNode::checked_type() const {
752-
ICHECK(checked_type_.defined()) << "internal error: the type checker has "
753-
<< "not populated the checked_type "
754-
<< "field for " << GetRef<RelaxExpr>(this);
755-
return this->checked_type_;
756-
}
757-
758-
template <typename TTypeNode>
759-
inline const TTypeNode* RelaxExprNode::type_as() const {
760-
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
761-
"TType must be a special case of type");
762-
ICHECK(checked_type_.defined())
763-
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
764-
const TTypeNode* node = checked_type_.as<TTypeNode>();
765-
ICHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get "
766-
<< checked_type_->GetTypeKey();
767-
return node;
768-
}
769-
770724
namespace ffi {
771725
// Type traits to enable automatic conversion into IntImm, Integer, and Bool
772726
// when called through the FFI

include/tvm/ir/function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ constexpr const char* kGlobalSymbol = "global_symbol";
131131
* \brief Base node of all functions.
132132
*
133133
* We support several variants of functions throughout the stack.
134-
* All of the functions share the same type system(via checked_type)
134+
* All of the functions share the same type system
135135
* to support cross variant calls.
136136
*
137137
* \sa BaseFunc

include/tvm/relax/dataflow_pattern.h

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ class AndPattern;
5555
class NotPattern;
5656
class ShapePattern;
5757
class StructInfoPattern;
58-
class TypePattern;
5958
class DataTypePattern;
6059
class AttrPattern;
6160
class SameShapeConstraint;
@@ -116,8 +115,6 @@ class DFPattern : public ObjectRef {
116115
TVM_DLL AttrPattern HasAttr(const Map<String, Any>& attrs) const;
117116
/*! \brief Syntatic Sugar for creating a StructInfoPattern */
118117
TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const;
119-
/*! \brief Syntatic Sugar for creating a TypePattern */
120-
TVM_DLL TypePattern HasType(const Type& type) const;
121118
/*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */
122119
TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const;
123120
/*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */
@@ -742,34 +739,6 @@ class WildcardPattern : public DFPattern {
742739
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
743740
};
744741

745-
/*!
746-
* \brief Pattern for matching a certain type.
747-
* \sa TypePattern
748-
*/
749-
class TypePatternNode : public DFPatternNode {
750-
public:
751-
DFPattern pattern; /*!< The pattern to match */
752-
Type type; /*!< The type to match */
753-
754-
void VisitAttrs(tvm::AttrVisitor* v) {
755-
v->Visit("pattern", &pattern);
756-
v->Visit("type", &type);
757-
}
758-
759-
static constexpr const char* _type_key = "relax.dpl.TypePattern";
760-
TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode);
761-
};
762-
763-
/*!
764-
* \brief Managed reference to TypePatternNode.
765-
* \sa TypePatternNode
766-
*/
767-
class TypePattern : public DFPattern {
768-
public:
769-
TVM_DLL TypePattern(DFPattern pattern, Type type);
770-
TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
771-
};
772-
773742
/*!
774743
* \brief Pattern for matching a certain struct info.
775744
* \sa StructInfoPattern

include/tvm/relax/dataflow_pattern_functor.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
9696
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
9797
virtual R VisitDFPattern_(const StructInfoPatternNode* op,
9898
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
99-
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
10099
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
101100
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
102101

@@ -132,7 +131,6 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
132131
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
133132
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
134133
RELAX_DFPATTERN_FUNCTOR_DISPATCH(StructInfoPatternNode);
135-
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
136134
RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
137135
RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
138136
RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode);
@@ -167,7 +165,6 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
167165
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
168166
void VisitDFPattern_(const TuplePatternNode* op) override;
169167
void VisitDFPattern_(const StructInfoPatternNode* op) override;
170-
void VisitDFPattern_(const TypePatternNode* op) override;
171168
void VisitDFPattern_(const WildcardPatternNode* op) override;
172169
void VisitDFPattern_(const VarPatternNode* op) override;
173170

include/tvm/relax/expr.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ class CallNode : public ExprNode {
166166
v->Visit("attrs", &attrs);
167167
v->Visit("sinfo_args", &sinfo_args);
168168
v->Visit("struct_info_", &struct_info_);
169-
v->Visit("_checked_type_", &checked_type_);
170169
v->Visit("span", &span);
171170
}
172171

@@ -224,7 +223,6 @@ class TupleNode : public ExprNode {
224223

225224
void VisitAttrs(tvm::AttrVisitor* v) {
226225
v->Visit("fields", &fields);
227-
v->Visit("_checked_type_", &checked_type_);
228226
v->Visit("struct_info_", &struct_info_);
229227
v->Visit("span", &span);
230228
}
@@ -291,7 +289,6 @@ class TupleGetItemNode : public ExprNode {
291289
v->Visit("tuple_value", &tuple);
292290
v->Visit("index", &index);
293291
v->Visit("struct_info_", &struct_info_);
294-
v->Visit("_checked_type_", &checked_type_);
295292
v->Visit("span", &span);
296293
}
297294

@@ -362,7 +359,6 @@ class ShapeExprNode : public LeafExprNode {
362359
void VisitAttrs(AttrVisitor* v) {
363360
v->Visit("values", &values);
364361
v->Visit("struct_info_", &struct_info_);
365-
v->Visit("_checked_type_", &checked_type_);
366362
v->Visit("span", &span);
367363
}
368364

@@ -399,7 +395,6 @@ class VarNode : public LeafExprNode {
399395
void VisitAttrs(AttrVisitor* v) {
400396
v->Visit("vid", &vid);
401397
v->Visit("struct_info_", &struct_info_);
402-
v->Visit("_checked_type_", &checked_type_);
403398
v->Visit("span", &span);
404399
}
405400

@@ -440,7 +435,6 @@ class DataflowVarNode : public VarNode {
440435
void VisitAttrs(AttrVisitor* v) {
441436
v->Visit("vid", &vid);
442437
v->Visit("struct_info_", &struct_info_);
443-
v->Visit("_checked_type_", &checked_type_);
444438
v->Visit("span", &span);
445439
}
446440

@@ -492,7 +486,6 @@ class ConstantNode : public LeafExprNode {
492486
void VisitAttrs(tvm::AttrVisitor* v) {
493487
v->Visit("data", &data);
494488
v->Visit("struct_info_", &struct_info_);
495-
v->Visit("_checked_type_", &checked_type_);
496489
v->Visit("span", &span);
497490
}
498491

@@ -540,7 +533,6 @@ class PrimValueNode : public LeafExprNode {
540533
void VisitAttrs(tvm::AttrVisitor* v) {
541534
v->Visit("value", &value);
542535
v->Visit("struct_info_", &struct_info_);
543-
v->Visit("_checked_type_", &checked_type_);
544536
v->Visit("span", &span);
545537
}
546538

@@ -591,7 +583,6 @@ class StringImmNode : public LeafExprNode {
591583
void VisitAttrs(tvm::AttrVisitor* v) {
592584
v->Visit("value", &value);
593585
v->Visit("struct_info_", &struct_info_);
594-
v->Visit("_checked_type_", &checked_type_);
595586
v->Visit("span", &span);
596587
}
597588

@@ -634,7 +625,6 @@ class DataTypeImmNode : public LeafExprNode {
634625
void VisitAttrs(tvm::AttrVisitor* v) {
635626
v->Visit("value", &value);
636627
v->Visit("struct_info_", &struct_info_);
637-
v->Visit("_checked_type_", &checked_type_);
638628
v->Visit("span", &span);
639629
}
640630

@@ -824,7 +814,6 @@ class SeqExprNode : public ExprNode {
824814
v->Visit("blocks", &blocks);
825815
v->Visit("body", &body);
826816
v->Visit("struct_info_", &struct_info_);
827-
v->Visit("_checked_type_", &checked_type_);
828817
v->Visit("span", &span);
829818
}
830819

@@ -889,7 +878,6 @@ class IfNode : public ExprNode {
889878
v->Visit("cond", &cond);
890879
v->Visit("true_branch", &true_branch);
891880
v->Visit("false_branch", &false_branch);
892-
v->Visit("_checked_type_", &checked_type_);
893881
v->Visit("struct_info_", &struct_info_);
894882
v->Visit("span", &span);
895883
}
@@ -966,7 +954,6 @@ class FunctionNode : public BaseFuncNode {
966954
v->Visit("ret_struct_info", &ret_struct_info);
967955
v->Visit("attrs", &attrs);
968956
v->Visit("struct_info_", &struct_info_);
969-
v->Visit("_checked_type_", &checked_type_);
970957
v->Visit("span", &span);
971958
}
972959

@@ -1071,7 +1058,6 @@ class ExternFuncNode : public BaseFuncNode {
10711058
void VisitAttrs(AttrVisitor* v) {
10721059
v->Visit("global_symbol", &global_symbol);
10731060
v->Visit("struct_info_", &struct_info_);
1074-
v->Visit("_checked_type_", &checked_type_);
10751061
v->Visit("span", &span);
10761062
}
10771063

include/tvm/relax/expr_functor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
312312
/*!
313313
* \brief A mutator works in unnormalized form.
314314
*
315-
* ExprMutatorBase expects input AST to be in the unnormalized form, i.e., checked_type_ and shape_
315+
* ExprMutatorBase expects input AST to be in the unnormalized form, i.e., struct_info_
316316
* of expressions can be nullptr, and the expressions may nest(and as a result the AST is not in
317317
* ANF).
318318
*/
@@ -414,7 +414,7 @@ class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
414414
* \brief A mutator works in normal form.
415415
*
416416
* ExprMutator expects input AST to be in the normal form, i.e., the expressions are normalized(no
417-
* nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are
417+
* nesting and hence the AST is in ANF), and all struct_info_ of expressions are
418418
* available.
419419
*/
420420
class ExprMutator : public ExprMutatorBase {

include/tvm/relax/transform.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ TVM_DLL Pass AttachGlobalSymbol();
149149

150150
/*!
151151
* \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the
152-
* checked_type_ and shape_ of expressions.
152+
* struct_info_ of expressions.
153153
*
154154
* \return The Pass.
155155
*/

include/tvm/tir/function.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ class PrimFuncNode : public BaseFuncNode {
107107
v->Visit("buffer_map", &buffer_map);
108108
v->Visit("attrs", &attrs);
109109
v->Visit("span", &span);
110-
v->Visit("_checked_type_", &checked_type_);
111110
}
112111

113112
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {

python/tvm/ir/expr.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from ..runtime import Object, Scriptable
2424
from . import _ffi_api
2525
from .base import Node, Span
26-
from .type import Type
2726

2827

2928
class BaseExpr(Node):
@@ -45,20 +44,6 @@ class PrimExpr(BaseExpr):
4544
class RelaxExpr(BaseExpr):
4645
"""Base class of all non-primitive expressions."""
4746

48-
@property
49-
def checked_type(self):
50-
"""Get the checked type of tvm.relax.Expr.
51-
52-
Returns
53-
-------
54-
checked_type : tvm.ir.Type
55-
The checked type.
56-
"""
57-
ret = self._checked_type_
58-
if ret is None:
59-
raise ValueError("The type checker has not populated the checked_type for this node")
60-
return ret
61-
6247
@property
6348
def struct_info(self) -> Optional["tvm.relax.StructInfo"]:
6449
"""Get the struct info field
@@ -86,8 +71,8 @@ class GlobalVar(RelaxExpr):
8671

8772
name_hint: str
8873

89-
def __init__(self, name_hint: str, type_annot: Optional[Type] = None):
90-
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot)
74+
def __init__(self, name_hint: str):
75+
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)
9176

9277
def __call__(self, *args: RelaxExpr) -> BaseExpr:
9378
"""Call the global variable.

0 commit comments

Comments
 (0)