@@ -120,19 +120,12 @@ class MatMulNBits final : public OpKernel {
   Status Compute(OpKernelContext* context) const override;

   Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
-                 bool save_prepacked_initializers,
                  /*out*/ bool& is_packed,
                  /*out*/ PrePackedWeights* prepacked_weights) override;

-  void ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx);
-
   Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
                                    /*out*/ bool& used_shared_buffers) override;

-  std::optional<Tensor> GetPrePackTensor(int /*input_idx*/) override;
-
-  Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override;
-
  private:
   const size_t K_;
   const size_t N_;
@@ -147,8 +140,6 @@ class MatMulNBits final : public OpKernel {
   size_t packed_b_size_{0};
   IAllocatorUniquePtr<float> scales_fp32_{};
   IAllocatorUniquePtr<float> bias_fp32_{};
-  std::optional<Tensor> packed_tensor_{std::nullopt};
-  MLDataType prepack_tensor_data_type_;

   bool has_zp_input_{false};

@@ -176,22 +167,8 @@ class MatMulNBits final : public OpKernel {
                         const MatMulComputeHelper& helper) const;
 };

-template <typename T1>
-void MatMulNBits<T1>::ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx) {
-  if (input_idx == InputIndex::B) {
-    prepack_tensor_data_type_ = tensor.DataType();
-  }
-
-  TensorShapeVector weights_dims = {static_cast<int64_t>((packed_b_size_ - 1) / prepack_tensor_data_type_->Size()) + 1};
-  packed_tensor_ = Tensor(prepack_tensor_data_type_,
-                          TensorShape(weights_dims),
-                          packed_b_.get(),
-                          OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));
-}
-
 template <typename T1>
 Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
-                                bool save_prepacked_initializers,
                                 /*out*/ bool& is_packed,
                                 /*out*/ PrePackedWeights* prepacked_weights) {
   ORT_UNUSED_PARAMETER(prepacked_weights);
@@ -227,18 +204,13 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
 #endif  // MLAS_TARGET_AMD64_IX86
   }

-  if (save_prepacked_initializers) {
-    ConvertPrepackWeightIntoTensor(tensor, input_idx);
-  }
-
   return Status::OK();
 }

 #if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || !defined(MLAS_TARGET_ARM64)
 // Non-ARM-with-fp16-intrinsics fall back fp16 to fp32.
 template <>
 Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
-                                       bool save_prepacked_initializers,
                                        /*out*/ bool& is_packed,
                                        /*out*/ PrePackedWeights* prepacked_weights) {
   ORT_UNUSED_PARAMETER(prepacked_weights);
@@ -288,34 +260,6 @@ Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*ou
 #endif  // MLAS_TARGET_AMD64_IX86
   }

-  if (save_prepacked_initializers) {
-    ConvertPrepackWeightIntoTensor(tensor, input_idx);
-  }
-
-  return Status::OK();
-}
-
-template <typename T1>
-std::optional<Tensor> MatMulNBits<T1>::GetPrePackTensor(int input_idx) {
-  // For this kernel, prepacking is performed on input_B and possibly scales and zero_points.
-  // During compute, scales and zero_points are used as-is; only input_B is replaced by the
-  // prepacked buffer.
-  // To match this logic, we return the latest prepacked buffer and serialize only that one,
-  // so packed_tensor_ is always returned here, not only for input_B.
-  ORT_UNUSED_PARAMETER(input_idx);
-  return std::move(packed_tensor_);
-}
-
-template <typename T1>
-Status MatMulNBits<T1>::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) {
-  if (input_idx == 1) {
-    // pre_packed_tensor is a constant initialized tensor whose lifecycle is managed by
-    // session_state, which will release its memory. packed_b_ must not release that memory,
-    // so an empty/default buffer deleter is passed here.
-    // The const_cast is temporary and will be fixed in a follow-up PR.
-    packed_b_ = BufferUniquePtr(const_cast<void*>(pre_packed_tensor.DataRaw()), BufferDeleter());
-  }
-
   return Status::OK();
 }
 #endif  // end !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64
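A note on the byte-to-element arithmetic in the removed ConvertPrepackWeightIntoTensor: `(packed_b_size_ - 1) / prepack_tensor_data_type_->Size() + 1` is ceiling division, sizing a 1-D tensor with just enough elements of the weight's data type to cover `packed_b_size_` bytes, so the pre-packed buffer can be exposed as a Tensor without copying. A minimal standalone sketch of that computation (the helper name and example values are illustrative, not from the source):

    #include <cstddef>
    #include <cstdio>

    // Ceiling division: the smallest element count whose combined size covers
    // `bytes`. Mirrors the removed `(packed_b_size_ - 1) / dtype->Size() + 1`.
    // Precondition: bytes >= 1 and element_size >= 1.
    static std::size_t ElementsCovering(std::size_t bytes, std::size_t element_size) {
      return (bytes - 1) / element_size + 1;
    }

    int main() {
      // 10 packed bytes viewed as 4-byte elements need 3 elements (12 bytes >= 10).
      std::printf("%zu\n", ElementsCovering(10, 4));  // prints 3
      // 12 bytes fit exactly into 3 elements; no over-allocation.
      std::printf("%zu\n", ElementsCovering(12, 4));  // prints 3
      return 0;
    }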