@@ -118,6 +118,16 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_cpu_amx:
+            return self.forward_cpu(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         positions: torch.Tensor,
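
The new `forward` entry point above dispatches to a backend-specific implementation at runtime, but always routes through the pure-PyTorch `forward_native` while `torch.compile` is tracing, so the compiler captures a traceable graph instead of an opaque custom kernel. A minimal, self-contained sketch of the same guard pattern (not sglang code; the doubling op stands in for a real fused kernel):

```python
import torch


class GuardedOp(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.compiler.is_compiling():
            # Traceable pure-PyTorch fallback: the branch Dynamo traces.
            return x * 2.0
        # Eager-only fast path; imagine a hand-written fused kernel here.
        return x.mul(2.0)


mod = GuardedOp()
eager_out = mod(torch.ones(4))                    # takes the eager branch
compiled_out = torch.compile(mod)(torch.ones(4))  # traces the fallback branch
assert torch.equal(eager_out, compiled_out)
```

Both branches compute the same result; the guard only changes which implementation the compiler captures.
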
@@ -148,6 +158,26 @@ def forward_native(
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if positions.device == torch.device("cpu") and _is_cpu_amx:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
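
`forward_cpu` double-checks that the position tensor actually lives on CPU (and that AMX support was detected) before invoking the fused `sgl_kernel` op, and otherwise falls back to `forward_native`. A hedged usage sketch of the new dispatcher; the import path and constructor arguments are assumptions based on this file, not taken from the diff:

```python
import torch

# Assumed import path and constructor signature for the surrounding class.
from sglang.srt.layers.rotary_embedding import RotaryEmbedding

rope = RotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=4096,
    base=10000,
    is_neox_style=True,
    dtype=torch.float16,
)
positions = torch.arange(8)                             # one position per token
query = torch.randn(8, 32 * 128, dtype=torch.float16)   # [tokens, heads * head_size]
key = torch.randn(8, 8 * 128, dtype=torch.float16)
# Routed to forward_cuda / forward_cpu / forward_native depending on the
# platform flags; under torch.compile it always traces forward_native.
query, key = rope(positions, query, key)
```
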