
Conversation

@Guobing-Chen (Owner) commented Aug 22, 2022

This PR re-implements aten::embedding with native NNC IR, which should provide better performance and fusion potential. The existing embedding implementation has been renamed with an externalcall suffix.

A new generic infrastructure function is implemented for indirect-indexing related Ops. It generates the overall loop nest for the Op and the indirect-indexing logic, while leaving Op-specific logic to be supplied by the Op through an injected function. E.g., take a 2D indices tensor and a 3D target (the buffer to be indirectly indexed), where the 2nd dim of the target is the indirectly indexed dim. The new infrastructure function generates the lowering statement below, while each Op specifies its own logic with a lambda function that forms the inner-most loop body.

for i : size_i
  for j : size_j
    x = indices[i, j]
    for m : size_m
      for n : size_n
        innerStmt by Op with innerStmtFunc(idxingTarget[m, x, n], [i, j, m, n])

aten::embedding is implemented on top of this new function very easily, as sketched below.
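For aten::embedding specifically, the indices are 2D, the weight is the 2D indirect-indexing target (indirectly indexed on its first dim), and the injected inner statement is a plain copy of the gathered element. The standalone (unfused) loop nest therefore specializes to roughly the following, written in the same form as the NNC IR dump further below; the dimension names batch/seq_len/embedding_dim are illustrative, not from the PR:

for (int64_t i = 0ll; i < batch; i++) {
  for (int64_t j = 0ll; j < seq_len; j++) {
    int64_t ind_idx = indices[i, j];
    for (int64_t k = 0ll; k < embedding_dim; k++) {
      output[i, j, k] = weight[ind_idx, k];
    }
  }
}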

This PR has been tested and verified with unit case as below:

import time
import torch
import torch.nn as nn

warm_up_num = 10
run_num = 200000

class EmbeddingModel(nn.Module):
  def __init__(self):
    super(EmbeddingModel, self).__init__()
    self.embedding = nn.Embedding(10000, 300)

  def forward(self, x):
    x = self.embedding(x)
    x = x + x
    x = x * x
    return x

if __name__ == '__main__':
  torch.manual_seed(0)
  model = EmbeddingModel()
  model.eval()
  print(model)

  jit_model = torch.jit.script(model)
  jit_model = torch.jit.freeze(jit_model)
  print("[INFO]: Before fusion")
  print(jit_model.graph)

  input_tensor = torch.LongTensor([[1,2,4,5],[4,3,2,9]])

  print("Warming up ...")
  with torch.no_grad():
    for i in range(warm_up_num):
      warmup_ts = time.time()
      out = jit_model(input_tensor)
      warmup_duration = (time.time() - warmup_ts)*1000000
      print(f"  round {i} : {warmup_duration} us")

  print("")
  print("Official run and benchmarking...")
  with torch.no_grad():
    start_ts = time.time()
    for i in range(run_num):
      out = jit_model(input_tensor)
    end_ts = time.time()
  print(f"Latency: {(end_ts-start_ts)/run_num*1000000} us")

With this PR, embedding is pulled into the NNC fusion group as below:

%x.26 : Float(2, 4, 300, strides=[1200, 300, 1], requires_grad=0, device=cpu) = aten::embedding(%self.embedding.weight, %x.1, %padding_idx.41, %2, %2)
%x.22 : Float(2, 4, 300, strides=[1200, 300, 1], requires_grad=0, device=cpu) = aten::add(%x.26, %x.26, %1)
%x.18 : Float(2, 4, 300, strides=[1200, 300, 1], requires_grad=0, device=cpu) = aten::mul(%x.22, %x.22)

And the related NNC IR (original stmt) is generated as below:

after fuse{
  for (int64_t i = 0ll; i < 2ll; i++) {
    for (int64_t j = 0ll; j < 4ll; j++) {
      int64_t ind_idx = tx_1[i, j];
      for (int64_t k_2 = 0ll; k_2 < 300ll; k_2++) {
        aten_mul[i, j, k_2] = ((const_self_embedding_weight[ind_idx, k_2]) + (const_self_embedding_weight[ind_idx, k_2])) * ((const_self_embedding_weight[ind_idx, k_2]) + (const_self_embedding_weight[ind_idx, k_2]));
      }
    }
  }
}

The performance with this PR (embedding+add+mul all fused in NNC) vs. before this PR (add+mul fused in NNC, embedding as an aten call) under JIT mode is shown in the table below (the unit test above was run with numactl to pin CPU cores on a CPX 4x8380H server). It shows about 80% to more than 200% performance benefit.

| Execution CPU | Vector DIM | Word Num | Indices | Pytorch (us) | Pytorch + NNC Embedding (us) | Perf Improvement |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 16 | 16 | [16] | 9.72 | 4.72 | 105.93% |
| 1 | 16 | 16 | [128] | 10.47 | 5.69 | 84.01% |
| 1 | 16 | 16 | [256, 256] | 570.15 | 199.63 | 185.60% |
| 1 | 16 | 10000 | [16] | 10.02 | 4.76 | 110.50% |
| 1 | 16 | 10000 | [128] | 10.41 | 5.79 | 79.79% |
| 1 | 16 | 10000 | [256, 256] | 762.89 | 268.16 | 184.49% |
| 1 | 128 | 100 | [16] | 10.36 | 5.08 | 103.94% |
| 1 | 128 | 100 | [128] | 13.57 | 5.82 | 133.16% |
| 1 | 128 | 100 | [256, 256] | 27898.13 | 13461.12 | 107.25% |
| 1 | 128 | 10000 | [16] | 10.38 | 5.13 | 102.34% |
| 1 | 128 | 10000 | [128] | 13.51 | 6.15 | 119.67% |
| 1 | 128 | 10000 | [256, 256] | 29312.62 | 14779.03 | 98.34% |
| 1 | 512 | 10000 | [16] | 11.41 | 5.29 | 115.69% |
| 1 | 512 | 10000 | [128] | 44.72 | 10.04 | 345.42% |
| 1 | 512 | 10000 | [256, 256] | 122491.56 | 62629.17 | 95.58% |
| 4 | 128 | 10000 | [16] | 10.45 | 4.96 | 110.69% |
| 4 | 128 | 10000 | [128] | 13.78 | 5.99 | 130.05% |
| 4 | 128 | 10000 | [256, 256] | 9771.33 | 4973.41 | 96.47% |
| 4 | 512 | 10000 | [16] | 11.58 | 5.54 | 109.03% |
| 4 | 512 | 10000 | [128] | 27.39 | 8.96 | 205.69% |
| 4 | 512 | 10000 | [256, 256] | 40666.72 | 20633.8 | 97.09% |

Re-implemented embedding with native NNC IR, and renamed the existing
external-call-based embedding implementation.
@Guobing-Chen (Owner, Author) commented:

Updated the performance numbers for this PR. @jgong5 and @EikanWang, please kindly review and comment on whether it is now good to upstream the PR.

// generating the overall loop-nest for the Op and indirect-indexing related
// logic, while leaving Op-specific logic to be defined by Op with injected
// function. E.x.: take a 2D indices and a 3D target (to be indirect indexing),
// and the 2nd dim is the dim to be indirect-indexing for i : size_i
Collaborator commented:

nit: formatting

Collaborator commented:

Can you explain the usage of the arguments by this example?

@Guobing-Chen (Owner, Author) commented Aug 29, 2022:

computeIndirectIndexing() should be called by an NNC lowering function like computeEmbedding().

posOfIndices - position of the indices in the inputs of the original lowering function; this should be 1 for embedding.
posOfIdxingTarget - position of the target (the buffer to be indirectly indexed) in the inputs of the original lowering function; this should be 0 for embedding.
dimOfIndirectIdxing - the indirect-indexing dim of the target; this should be 1 (the 2nd dim) for this case.

Maybe I should directly pass the indices and target NNC BufHandles instead of the positions and the entire inputs, which may be more self-explanatory.
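For illustration only, a call wiring these arguments from the embedding lowering might look roughly like the sketch below. The signature and callback shape are hypothetical (modeled on the usual NNC lowering-function form and the parameter names above), not the exact code of this PR:

// Hypothetical sketch -- the actual signatures of computeEmbedding() and
// computeIndirectIndexing() in this PR may differ. Types are from
// torch::jit::tensorexpr (Tensor, ArgValue, ExprHandle, VarHandle).
Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device) {
  return computeIndirectIndexing(
      inputs,
      outputShape,
      /*posOfIndices=*/1,        // indices tensor is inputs[1] for embedding
      /*posOfIdxingTarget=*/0,   // weight (the indirect-indexing target) is inputs[0]
      /*dimOfIndirectIdxing=*/0, // the weight is indirectly indexed on its 1st dim
      // Op-specific inner-most loop body: for embedding it is a plain copy of
      // the gathered element, so the injected lambda just returns it.
      [](const ExprHandle& gatheredVal, const std::vector<VarHandle>& axes) {
        return gatheredVal;
      });
}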

@Guobing-Chen (Owner, Author) commented:

> nit: formatting

Do you mean lint formatting?

Collaborator commented:

> computeIndirectIndexing() should be called by a NNC lowering function like computeEmbedding(). [...]

Oh, I meant to explain that in the code comment, and to explain it with this example.

Collaborator commented:

Yes, that is clearer. BTW, it is basically a unary point-wise op with indirect indexing on one of the dims, right? Is that too specific an abstraction? Do you think we should call out such semantics in the function name?

@Guobing-Chen (Owner, Author) commented:

I thought about this previously, but stayed with the current naming because, from a capability perspective, this function can support non-point-wise operations, such as loop-carried operations (e.g. output[i, j, k] = output[i, j, k-1]) or even more complicated computation. It can also help Ops with binary or more operands, with the only constraint that only one operand can be indirectly indexed. I cannot come up with a concrete Op for these scenarios, but this function keeps the possibility open.

Collaborator commented:

I don't have a strong opinion here, but I think designing the semantics with DL use cases in mind would make things simpler. The use cases in my mind are unary pointwise (embedding lookup, index select, upsampling, etc.) and sparse matmul (a sparse CSR tensor multiplied by a dense tensor). Do you target the interface to cover both scenarios or only the former? Is there any other use case you want to target?

@Guobing-Chen (Owner, Author) commented:

Currently it targets the former. There is no sparse support in NNC yet, so I am not sure whether this interface can fully support spmm-related ops.

Collaborator commented:

OK, that's why I thought making the naming/semantics specific to the pointwise use case would be simpler. Anyway, you decide. :)
