Skip to content

Commit fd8ee48

Browse files
[JS/WebGPU] GroupQueryAttention rewrite (#20946)
### Description Implement JSEP GroupQueryAttention ### Motivation and Context Required to enable certain LLM models to run using WebGPU.
1 parent 33e2f6a commit fd8ee48

File tree

7 files changed

+1304
-547
lines changed

7 files changed

+1304
-547
lines changed

js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather';
1919
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
2020
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
2121
import { gemm, parseGemmAttributes } from './ops/gemm';
22-
import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention';
22+
import { groupQueryAttention } from './ops/group-query-attention';
2323
import { instanceNorm } from './ops/instance-norm';
2424
import { layerNorm } from './ops/layer-norm';
2525
import { matMul } from './ops/matmul';
@@ -104,7 +104,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
104104
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
105105
['Greater', [binaryOps.greater]],
106106
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
107-
['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
107+
['GroupQueryAttention', [groupQueryAttention]],
108108
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
109109
['InstanceNormalization', [instanceNorm]],
110110
['LayerNormalization', [layerNorm]],

js/web/lib/wasm/jsep/webgpu/ops/attention.ts

Lines changed: 258 additions & 96 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)