
Commit 103e54b

Deivanayaki-S and deivanayakisankaralingam authored
[Relax][PyTorch] Add PReLU Op Support for Exported Program and FX graph (#17816)
* prelu op support and test script added
* end-of-file issue fixed
* trailing whitespace issue fixed
* fixing lint issues
* fix assertion error in test_op_nn.py file
* add test script in test_frontend_nn_op.py
* include wrapper function for prelu in op.py
* fixing unity check issue by modifying test func
* conflicts resolved
* add doc for prelu op axis arg
* fixed failing checks issue

---------

Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
1 parent 601d570 commit 103e54b

15 files changed, +194 −0 lines changed

include/tvm/relax/attrs/nn.h

Lines changed: 9 additions & 0 deletions
@@ -468,6 +468,15 @@ struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> {
   }
 };
 
+/*! \brief Attributes used in PReLU operator */
+struct PReluAttrs : public tvm::AttrsNode<PReluAttrs> {
+  int axis;
+
+  TVM_DECLARE_ATTRS(PReluAttrs, "relax.attrs.PReluAttrs") {
+    TVM_ATTR_FIELD(axis).describe("The axis along which the alpha values are applied.");
+  }
+};
+
 /*! \brief Attributes used in batch_norm operator */
 struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
   int axis;

python/tvm/relax/frontend/nn/op.py

Lines changed: 28 additions & 0 deletions
@@ -1072,6 +1072,34 @@ def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str =
     return wrap_nested(_op.nn.softplus(x._expr, beta=beta, threshold=threshold), name)
 
 
+def prelu(x: Tensor, alpha: Tensor, name: str = "prelu"):
+    r"""Parametric ReLU activation function.
+
+    .. math::
+        \text{PReLU}(x) = \begin{cases}
+        x & \text{if } x \geq 0 \\
+        \alpha \cdot x & \text{if } x < 0
+        \end{cases}
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data.
+
+    alpha : Tensor
+        Slope coefficient for the negative part of the input.
+
+    name : str, optional
+        Optional name for the operation. Default is "prelu".
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.nn.prelu(x._expr, alpha._expr), name)
+
+
 def tanh(x: Tensor, name: str = "tanh") -> Tensor:
     r"""Applies the hyperbolic tangent function.
 
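For orientation, a minimal usage sketch of this new wrapper through the relax.frontend.nn module; the PReluBlock class, channel count, and input shape below are illustrative assumptions, not part of this change:

from tvm.relax.frontend import nn


class PReluBlock(nn.Module):
    """Hypothetical block exercising the new nn.op.prelu wrapper."""

    def __init__(self, channels: int):
        super().__init__()
        # Per-channel learnable slope passed as the alpha tensor.
        self.alpha = nn.Parameter((channels,), dtype="float32")

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        return nn.op.prelu(x, self.alpha)


# Export to a Relax IRModule to inspect the emitted relax.nn.prelu call.
mod, _ = PReluBlock(8).export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, 8, 32, 32), "float32")}}
)
mod.show()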

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 6 additions & 0 deletions
@@ -307,6 +307,12 @@ def _log_softmax(self, node: fx.Node) -> relax.Var:
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
+    def _prelu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = self.env[node.args[1]]
+        axis = 0 if len(x.struct_info.shape) == 1 else 1
+        return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))
+
     def _round(self, node: fx.Node) -> relax.Expr:
         if node.kwargs.get("decimals", 0) != 0:
             raise ValueError("specifying decimals for round is not supported yet")
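The axis chosen above mirrors PyTorch's convention: the learnable slope is applied along the channel dimension (dim 1) for batched inputs and along dim 0 for 1-D inputs. A small PyTorch-only sketch of that convention, with shapes chosen purely for illustration:

import torch

# Per-channel slope for a 4-D NCHW input: applied along dim 1.
prelu_c = torch.nn.PReLU(num_parameters=4)
y4d = prelu_c(torch.randn(2, 4, 8, 8))

# 1-D input: a single shared slope, applied along dim 0.
prelu_1 = torch.nn.PReLU()
y1d = prelu_1(torch.randn(4))

print(y4d.shape, y1d.shape)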

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
@@ -299,6 +299,7 @@ def create_convert_map(
             "log1p.default": self._log1p,
             "log_softmax.int": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
+            "prelu.default": self._prelu,
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
             "relu_.default": self._unary_op(relax.op.nn.relu),

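With prelu.default registered in the convert map, models using nn.PReLU (or torch.nn.functional.prelu) can be brought in through the exported-program path. A rough sketch, assuming the existing torch.export and from_exported_program entry points; the model and shapes are illustrative:

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program


class PReluModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.act = torch.nn.PReLU(num_parameters=3)

    def forward(self, x):
        return self.act(x)


# prelu.default in the exported graph is now dispatched to _prelu.
exported = export(PReluModel().eval(), (torch.randn(1, 3, 10, 10),))
mod = from_exported_program(exported)
mod.show()
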
python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 10 additions & 0 deletions
@@ -103,6 +103,14 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var:
         assert dim is not None
         return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
+    def _prelu_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        alpha_tensor = module.weight.numpy()
+        alpha = relax.const(alpha_tensor, dtype="float32")
+        axis = 0 if len(x.struct_info.shape) == 1 else 1  # Extract Channel size
+        return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))
+
     def _softmax_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -595,6 +603,7 @@ def create_convert_map(
             nn.Identity: lambda node: self.env[node.args[0]],
             nn.LeakyReLU: self._leakyrelu_module,
             nn.LogSoftmax: self._log_softmax_module,
+            nn.PReLU: self._prelu_module,
             nn.ReLU: self._unary_op(relax.op.nn.relu),
             nn.ReLU6: lambda node: self.block_builder.emit(
                 relax.op.clip(self.env[node.args[0]], 0, 6)
@@ -657,6 +666,7 @@ def create_convert_map(
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "prelu": self._prelu,
            "reciprocal": self._reciprocal,
            "relu": self._unary_op(relax.op.nn.relu),
            "round": self._round,

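The FX path now covers both routes touched above: the nn.PReLU module form (where _prelu_module reads the learned weight as a constant alpha) and the functional "prelu" call handled by the shared _prelu. A sketch of the module route via from_fx, again with illustrative shapes:

import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx


class PReluModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.act = torch.nn.PReLU(num_parameters=3)

    def forward(self, x):
        return self.act(x)


# Symbolically trace and translate; module.weight becomes a relax constant alpha.
graph_module = fx.symbolic_trace(PReluModel())
mod = from_fx(graph_module, [((1, 3, 10, 10), "float32")])
mod.show()
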
python/tvm/relax/op/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@
     max_pool3d,
     nll_loss,
     pad,
+    prelu,
     relu,
     rms_norm,
     selu,

python/tvm/relax/op/nn/nn.py

Lines changed: 26 additions & 0 deletions
@@ -1431,6 +1431,32 @@ def log_softmax(data: Expr, axis: int = -1) -> Expr:
     return _ffi_api.log_softmax(data, axis)  # type: ignore
 
 
+def prelu(data: Expr, alpha: Expr, axis: int = 1) -> Expr:
+    r"""Parametric Rectified Linear Unit (PReLU).
+
+    .. math::
+        PReLU(x) = x \text{ if } x > 0 \text{ else } \alpha * x
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor.
+
+    alpha : relax.Expr
+        The learnable slope tensor, applied channel-wise.
+
+    axis : int
+        The axis along which the `alpha` values are applied.
+        Default is 1 (assuming NCHW format).
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.prelu(data, alpha, axis)
+
+
 def batch_norm(
     data: Expr,
     gamma: Expr,
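A quick sketch of calling the op directly at the Relax level; variable names and shapes are illustrative:

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 3, 10, 10), "float32"))
alpha = relax.Var("alpha", relax.TensorStructInfo((3,), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [x, alpha]):
    # axis=1 matches the NCHW default documented above.
    out = bb.emit(relax.op.nn.prelu(x, alpha, axis=1))
    bb.emit_func_output(out)

mod = bb.get()
mod.show()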

python/tvm/relax/transform/legalize_ops/nn.py

Lines changed: 5 additions & 0 deletions
@@ -469,6 +469,11 @@ def _nn_leakyrelu(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.nn.leaky_relu, call.args[0], call.attrs.alpha)
 
 
+@register_legalize("relax.nn.prelu")
+def _nn_prelu(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], call.attrs.axis)
+
+
 @register_legalize("relax.nn.gelu")
 def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr:
     def te_gelu(x: te.Tensor):
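After this registration, relax.transform.LegalizeOps() lowers relax.nn.prelu calls into TIR via topi.nn.prelu. Continuing the BlockBuilder sketch above (it assumes mod is an IRModule containing a relax.nn.prelu call):

from tvm import relax

# mod is assumed to hold a relax.nn.prelu call, e.g. from the earlier sketch.
legalized = relax.transform.LegalizeOps()(mod)
legalized.show()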

python/tvm/topi/nn/elemwise.py

Lines changed: 3 additions & 0 deletions
@@ -129,6 +129,9 @@ def prelu(x, slope, axis=1):
 
     assert len(slope.shape) == 1
     assert axis < len(x.shape)
+    slope = te.compute(
+        (get_const_int(x.shape[axis]),), lambda c: slope[0], name="slope_broadcasted"
+    )
     assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])
 
     def _compute_channelwise(*indices):
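For reference, the channel-wise computation that topi.nn.prelu expresses corresponds to this NumPy sketch (semantics only, not the TE compute or the slope broadcast added above):

import numpy as np


def prelu_ref(x: np.ndarray, slope: np.ndarray, axis: int = 1) -> np.ndarray:
    """NumPy reference: x where x >= 0, slope * x elsewhere, slope applied along `axis`."""
    shape = [1] * x.ndim
    shape[axis] = slope.shape[0]
    return np.where(x >= 0, x, x * slope.reshape(shape))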

src/relax/op/nn/nn.cc

Lines changed: 21 additions & 0 deletions
@@ -81,6 +81,27 @@ TVM_REGISTER_OP("relax.nn.softplus")
                                InferStructInfoUnaryArith</*require_float_dtype=*/true>)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.prelu */
+TVM_REGISTER_NODE_TYPE(PReluAttrs);
+
+Expr prelu(Expr data, Expr alpha, int axis = 1) {
+  auto attrs = make_object<PReluAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("relax.nn.prelu");
+  return Call(op, {data, alpha}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu);
+
+TVM_REGISTER_OP("relax.nn.prelu")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("alpha", "Tensor", "The channel-wise learnable slope.")
+    .set_attrs_type<PReluAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo",
+                                InferStructInfoUnaryArith</*require_float_dtype=*/true>)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.nn.softmax */
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
 
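Because the op is exposed through the relax.op.nn.prelu FFI hook and carries PReluAttrs, the stored axis can be inspected directly on a constructed call; a sketch with illustrative shapes:

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 3, 10, 10), "float32"))
alpha = relax.Var("alpha", relax.TensorStructInfo((3,), "float32"))

call = relax.op.nn.prelu(x, alpha)   # default axis=1
print(call.op.name)                  # relax.nn.prelu
print(call.attrs.axis)               # 1, stored in PReluAttrs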
