Implement embedding with NNC IR #5
base: master
Conversation
Re-implemented embedding with native NNC IR, and renamed the existing external-call-based embedding implementation.
Updated the performance numbers for this PR. @jgong5 and @EikanWang, please kindly comment and see whether it is now good for PR upstream.
// generating the overall loop-nest for the Op and indirect-indexing related
// logic, while leaving Op-specific logic to be defined by Op with injected
// function. E.x.: take a 2D indices and a 3D target (to be indirect indexing),
// and the 2nd dim is the dim to be indirect-indexing
for i : size_i
nit: formatting
Can you explain the usage of the arguments by this example?
computeIndirectIndexing() should be called by an NNC lowering function like computeEmbedding().
posOfIndices - position of the indices in the Inputs of the original lowering function, which should be 1 for embedding
posOfIdxingTarget - position of the target (to be indirect-indexed) in the Inputs of the original lowering function, which should be 0 for embedding
dimOfIndirectIdxing - indirect-indexing dim of the target, which should be 1 (the 2nd dim) for this case
Maybe I should directly pass the indices and target NNC BufHandle instead of the positions and the entire Inputs, which may be more self-explanatory.
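To make the argument mapping concrete, here is a hypothetical scalar sketch in plain C++ (not actual NNC tensorexpr code; the function name, shapes, and layout are illustrative assumptions) of what the helper does for the example in the code comment: 2D indices, 3D target, indirect-indexed on dim 1, with the Op-specific logic injected as the inner-most body:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical scalar analogue of computeIndirectIndexing(): the helper owns
// the loop nest and the indices lookup, while the Op supplies the inner-most
// body as an injected function. Shapes follow the example in the comment:
// 2D indices [I, J], 3D target [D0, D1, D2], indirect-indexed on dim 1.
using InnerBody = std::function<float(const std::vector<float>& target,
                                      int64_t d0, int64_t idx, int64_t d2,
                                      int64_t D1, int64_t D2)>;

std::vector<float> indirectIndexLoopNest(
    const std::vector<int64_t>& indices, int64_t I, int64_t J,
    const std::vector<float>& target, int64_t D0, int64_t D1, int64_t D2,
    const InnerBody& body) {
  // Output shape: [D0, I, J, D2] -- dim 1 of the target is replaced by the
  // dims of the indices tensor.
  std::vector<float> out(D0 * I * J * D2);
  for (int64_t d0 = 0; d0 < D0; ++d0)
    for (int64_t i = 0; i < I; ++i)
      for (int64_t j = 0; j < J; ++j) {
        int64_t idx = indices[i * J + j];  // indirect index into dim 1
        for (int64_t d2 = 0; d2 < D2; ++d2)
          out[((d0 * I + i) * J + j) * D2 + d2] =
              body(target, d0, idx, d2, D1, D2);
      }
  return out;
}
```

A pointwise Op would then only need to pass a lambda that reads `target[(d0 * D1 + idx) * D2 + d2]`.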
nit: formatting
Do you mean lint formatting?
Oh, I mean to explain that in the comment. And explain with this example.
Yes, it is clearer. BTW, it is basically a unary point-wise op with indirect indexing on one of the dims, right? Is it too specific an abstraction? Do you think we should call out such semantics in the function name?
I thought about this previously, but stayed with the current naming because, from a capability perspective, this function can support non-point-wise operations, like loop-carried operations: output[i, j, k] = output[i, j, k-1], or even more complicated computation. This function can also help with Ops that have binary or even more operands, with the only constraint that only one operand can be indirect-indexed. I cannot come up with a concrete Op for these scenarios, but this function can support the possibility.
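The loop-carried scenario mentioned above can be sketched in scalar form (illustrative plain C++ only, not NNC; the function name and shapes are assumptions) — the injected inner-most body reads a previously written output element, so it is not pointwise:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative-only sketch: the inner-most body can be loop-carried, e.g. a
// running sum along the last dim of each gathered row:
//   out[i, j, k] = out[i, j, k-1] + target[indices[i, j], k]
std::vector<float> gatherPrefixSum(const std::vector<int64_t>& indices,
                                   int64_t I, int64_t J,
                                   const std::vector<float>& target,
                                   int64_t K) {
  std::vector<float> out(I * J * K);
  for (int64_t i = 0; i < I; ++i)
    for (int64_t j = 0; j < J; ++j)
      for (int64_t k = 0; k < K; ++k) {
        float cur = target[indices[i * J + j] * K + k];
        // Loop-carried dependence: read the element written at k-1.
        float prev = (k == 0) ? 0.0f : out[(i * J + j) * K + k - 1];
        out[(i * J + j) * K + k] = prev + cur;
      }
  return out;
}
```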
I don't have a strong opinion here, but I think designing the semantics with DL use cases in consideration would make things simpler. The use cases in my mind are unary pointwise (embedding lookup, index select, upsampling, etc.) and sparse matmul (a sparse tensor in CSR format multiplied by a dense tensor). Do you target the interface to cover both scenarios or only the former? Is there any other use case you want to target?
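For reference, the SpMM use case mentioned here looks like this in scalar form (illustrative plain C++ only, not NNC code; names are assumptions). Note the indirect indexing drives a variable inner trip count, `rowPtr[r+1] - rowPtr[r]`, which is what makes it harder to fit under a fixed loop-nest abstraction than the pointwise gathers:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// CSR sparse matrix (rowPtr/colIdx/vals) times a dense [rows x cols] matrix.
std::vector<float> csrSpmm(const std::vector<int64_t>& rowPtr,
                           const std::vector<int64_t>& colIdx,
                           const std::vector<float>& vals,
                           const std::vector<float>& dense,
                           int64_t rows, int64_t cols) {
  std::vector<float> out(rows * cols, 0.0f);
  for (int64_t r = 0; r < rows; ++r)
    // Variable-length loop over the nonzeros of row r.
    for (int64_t p = rowPtr[r]; p < rowPtr[r + 1]; ++p)
      for (int64_t c = 0; c < cols; ++c)
        out[r * cols + c] += vals[p] * dense[colIdx[p] * cols + c];
  return out;
}
```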
Currently it is targeting the former. There is no sparse support in NNC yet, so I am not sure whether this interface can fully support SpMM-related cases.
OK, that's why I thought making the naming/semantics specific to the pointwise use case would be simpler. Anyway, you decide. :)
This PR re-implements aten::embedding with native NNC IR, which should provide better performance and fusion potential. The existing embedding implementation has been renamed with an external-call suffix.
A new generic infrastructure function is implemented for indirect-indexing related Ops. It helps generate the overall loop-nest for the Op and the indirect-indexing related logic, while leaving Op-specific logic to be defined by each Op via an injected function. E.g.: take a 2D indices and a 3D target (to be indirect-indexed), where the 2nd dim is the dim to be indirect-indexed. This new infrastructure function will generate the lowering statement below, while each Op can specify its specific logic with a lambda function as the inner-most loop body.
aten::embedding is implemented very easily on top of the above new function.
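As a hedged illustration (scalar plain C++, not the actual generated NNC IR), the lowered loop-nest for an embedding lookup has roughly this shape — out[i, j, :] = weight[indices[i, j], :]:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// aten::embedding semantics: indices [I, J] into weight [V, E] -> out
// [I, J, E]. The inner-most assignment is the part an Op would inject.
std::vector<float> embeddingLookup(const std::vector<int64_t>& indices,
                                   int64_t I, int64_t J,
                                   const std::vector<float>& weight,
                                   int64_t V, int64_t E) {
  (void)V;  // vocabulary size, kept for signature clarity
  std::vector<float> out(I * J * E);
  for (int64_t i = 0; i < I; ++i)
    for (int64_t j = 0; j < J; ++j)
      for (int64_t e = 0; e < E; ++e)
        // Gather row indices[i, j] of the weight table.
        out[(i * J + j) * E + e] = weight[indices[i * J + j] * E + e];
  return out;
}
```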
This PR has been tested and verified with the unit test as below:
With this PR, embedding will be pulled into NNC fusion group as below:
And the related NNC IR (original) is generated as below:
The performance of this PR (embedding+add+mul all fused in NNC) vs. before this PR (add+mul fused in NNC, embedding as an aten call) under JIT mode is shown in the table below (running the above unit test with numactl to pin CPU cores on a CPX 4x8380H server), which shows about 80% to more than 200% performance benefit.