
Conversation


@Guobing-Chen Guobing-Chen commented Aug 21, 2022

This PR provides a decompose pass in TE graph optimization to decompose specified Ops
into a series of other Ops providing equivalent functionality. It also adds aten::linear support in TE based on the newly added decompose pass, since Linear can by nature be constructed from matmul and add.

The decompose pass helps TE support more Ops, such as aten::linear, that can by nature be constructed from other Ops, and it also covers other scenarios where decomposition is beneficial for performance or eases lowering/optimization inside TE.
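
For illustration, here is a minimal sketch (in eager PyTorch, not part of this PR's C++ code) of the equivalence the decompose pass relies on: aten::linear(x, w, b) can be rewritten as a transpose, a matmul, and an add.

import torch
import torch.nn.functional as F

def decomposed_linear(x, weight, bias=None):
    # aten::linear(x, w, b) == matmul(x, t(w)) + b
    out = torch.matmul(x, weight.t())
    if bias is not None:
        out = out + bias
    return out

x = torch.randn(1, 82)
w = torch.randn(64, 82)
b = torch.randn(64)
assert torch.allclose(F.linear(x, w, b), decomposed_linear(x, w, b), atol=1e-6)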

Two tricky parts inside the decompose process:

  • A pattern Op (graph) may be decomposed into two or more different target graphs. A common scenario is Ops with an optional bias: two target graphs are needed, one with and one without bias, and the right target graph is selected at runtime based on the real input (see the sketch after this list).
  • As the target graph usually contains more than one Op, there will be newly added TorchScript Values in the target graph besides the graph inputs/outputs. These Values have no shape information embedded, since they never went through the profiling run of the graph executor. This blocks TE, because shape information is required during Op lowering. To deal with this, a dedicated mechanism is added to reformalize these Values with shapes, by providing a shape reformalization function for each Op to be decomposed.
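
To make the two points above concrete, here is a hedged sketch of the runtime selection between the two target graphs and of a shape helper for the intermediate matmul Value; the function names are hypothetical and only mirror the idea, not the actual C++ implementation in this PR.

import torch

# Hypothetical with-bias and without-bias target graphs for aten::linear.
def linear_with_bias(x, weight, bias):
    return torch.add(torch.matmul(x, weight.t()), bias)

def linear_without_bias(x, weight):
    return torch.matmul(x, weight.t())

def select_linear_target(x, weight, bias=None):
    # The right target graph is picked at runtime based on the real input.
    if bias is None:
        return linear_without_bias(x, weight)
    return linear_with_bias(x, weight, bias)

# Hypothetical shape reformalization for the intermediate matmul Value,
# which has no profiled shape because it never went through a profiling run.
def matmul_output_shape(a_shape, b_shape):
    # 2-D case only, for brevity: (n, k) x (k, m) -> (n, m)
    return (a_shape[0], b_shape[1])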

This PR has been tested with unit tests and also on the Wide&Deep model, where aten::linear is successfully pulled into the NNC fusion group and delivers better performance.

  • TorchScript graph before decomposing Linear
%x_cont.1 : Float(1, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%5, %13, %12, %12, %11)
%15 : Tensor[] = prim::ListConstruct(%1, %2, %3, %4, %x_cont.1)
%input.20 : Float(1, 82, strides=[82, 1], requires_grad=0, device=cpu) = aten::cat(%15, %10)
%input.16 : Float(1, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::linear(%input.20, %self.deepdense.0.dense.dense_layer_0.0.weight, %self.deepdense.0.dense.dense_layer_0.0.bias)
%input.24 : Float(1, 749, strides=[749, 1], requires_grad=0, device=cpu) = aten::to(%X.3, %13, %12, %12, %11)
%out.1 : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu) = aten::linear(%input.24, %self.wide.wide_linear.weight, %self.wide.wide_linear.bias)
  • TorchScript graph after decomposing Linear
%x_cont.1 : Float(1, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%5, %13, %12, %12, %11)
%15 : Tensor[] = prim::ListConstruct(%1, %2, %3, %4, %x_cont.1)
%input.20 : Float(1, 82, strides=[82, 1], requires_grad=0, device=cpu) = aten::cat(%15, %10)
%21 : Float(82, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::t(%self.deepdense.0.dense.dense_layer_0.0.weight)
%22 : Float(1, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::matmul(%input.20, %21)
%23 : Float(1, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::add(%22, %self.deepdense.0.dense.dense_layer_0.0.bias, %10)
%input.24 : Float(1, 749, strides=[749, 1], requires_grad=0, device=cpu) = aten::to(%X.3, %13, %12, %12, %11)
%25 : Float(749, 1, strides=[1, 1], requires_grad=0, device=cpu) = aten::t(%self.wide.wide_linear.weight)
%26 : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu) = aten::matmul(%input.24, %25)
%27 : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu) = aten::add(%26, %self.wide.wide_linear.bias, %10)

Add a decompose pass in TE graph optimization to decompose specified Ops
into a series of other Ops providing equivalent functionality.

Add aten::linear support in TE based on the newly added decompose pass,
as Linear can by nature be constructed from matmul and add.
Collaborator

@jgong5 jgong5 left a comment


The concern with adding such decomposition logic as an NNC TS pass is that it duplicates the decomposition support that PyTorch is adding in the Python frontend. It is also more complicated due to the need for shape propagation, which is not required if it is done from the frontend. Have we considered a lighter-weight approach that does "decomposition" directly from inside the lowering function? Can we do the lowering of "linear" by calling the lowering of "mm" and "add" directly?

@Guobing-Chen Guobing-Chen (Owner, Author)

The concern with adding such decomposition logic as an NNC TS pass is that it duplicates the decomposition support that PyTorch is adding in the Python frontend. It is also more complicated due to the need for shape propagation, which is not required if it is done from the frontend. Have we considered a lighter-weight approach that does "decomposition" directly from inside the lowering function? Can we do the lowering of "linear" by calling the lowering of "mm" and "add" directly?

Yes, I agree that doing the decompose at the NNC TS pass level is a bit more complicated than doing it at the Python frontend. However, I am not sure whether we can depend on that frontend work for our NNC support.

As for the alternative of doing the decompose at the lowering function level, my thinking is that the NNC TS pass approach provides a generic mechanism that can be reused for other Ops besides Linear when needed, which is a benefit compared with decomposing inside the lowering functions. And shape propagation is required either way.

Decomposing at the NNC TS pass level is also quite similar to our previous approach for quantization support, which can actually share the same entry point for decomposing Ops during the NNC optimization pass. We can justify both with the same philosophy, IMHO.
