@@ -98,19 +98,12 @@ class MatMulNBits final : public OpKernel {
   Status Compute(OpKernelContext* context) const override;

   Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
-                 bool save_prepacked_initializers,
                  /*out*/ bool& is_packed,
                  /*out*/ PrePackedWeights* prepacked_weights) override;

-  void ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx);
-
   Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
                                    /*out*/ bool& used_shared_buffers) override;

-  std::optional<Tensor> GetPrePackTensor(int /*input_idx*/) override;
-
-  Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override;
-
  private:
   const size_t K_;
   const size_t N_;
@@ -126,8 +119,6 @@ class MatMulNBits final : public OpKernel {
   size_t packed_b_size_{0};
   IAllocatorUniquePtr<float> scales_fp32_{};
   IAllocatorUniquePtr<float> bias_fp32_{};
-  std::optional<Tensor> packed_tensor_{std::nullopt};
-  MLDataType prepack_tensor_data_type_;

   bool has_zp_input_{false};

@@ -157,22 +148,8 @@ class MatMulNBits final : public OpKernel {
   }
 };

-template <typename T1>
-void MatMulNBits<T1>::ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx) {
-  if (input_idx == InputIndex::B) {
-    prepack_tensor_data_type_ = tensor.DataType();
-  }
-
-  TensorShapeVector weights_dims = {static_cast<int64_t>((packed_b_size_ - 1) / prepack_tensor_data_type_->Size()) + 1};
-  packed_tensor_ = Tensor(prepack_tensor_data_type_,
-                          TensorShape(weights_dims),
-                          packed_b_.get(),
-                          OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));
-}
-
 template <typename T1>
 Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
-                                bool save_prepacked_initializers,
                                 /*out*/ bool& is_packed,
                                 /*out*/ PrePackedWeights* prepacked_weights) {
   ORT_UNUSED_PARAMETER(prepacked_weights);
@@ -208,16 +185,11 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
 #endif  // MLAS_TARGET_AMD64_IX86
   }

-  if (save_prepacked_initializers) {
-    ConvertPrepackWeightIntoTensor(tensor, input_idx);
-  }
-
   return Status::OK();
 }

 template <>
 Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
-                                       bool save_prepacked_initializers,
                                        /*out*/ bool& is_packed,
                                        /*out*/ PrePackedWeights* prepacked_weights) {
   ORT_UNUSED_PARAMETER(prepacked_weights);
@@ -267,34 +239,6 @@ Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*ou
 #endif  // MLAS_TARGET_AMD64_IX86
   }

-  if (save_prepacked_initializers) {
-    ConvertPrepackWeightIntoTensor(tensor, input_idx);
-  }
-
-  return Status::OK();
-}
-
-template <typename T1>
-std::optional<Tensor> MatMulNBits<T1>::GetPrePackTensor(int input_idx) {
-  // For this kernel, prepack is performed on input_B, and possibly scales, zeros_points.
-  // During compute process, scales and zeros_points will keep as it is and only use prepacked
-  // buffer to replace input_B.
-  // Inorder to cope with this logic, we need to return latest prepacked buffer and only serialize
-  // the latest one. So, we need to always return packed_tensor_ here not only for input_B.
-  ORT_UNUSED_PARAMETER(input_idx);
-  return std::move(packed_tensor_);
-}
-
-template <typename T1>
-Status MatMulNBits<T1>::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) {
-  if (input_idx == 1) {
-    // pre_packed_tensor is constant initialized tensor and its lifecycle is managed by session_state,
-    // session_state will release memory from pre_packed_tensor. packed_b_ will not release memory so
-    // pass empty/default buffer deleter here.
-    // const_cast here is temporary, will fix in follow up PR.
-    packed_b_ = BufferUniquePtr(const_cast<void*>(pre_packed_tensor.DataRaw()), BufferDeleter());
-  }
-
   return Status::OK();
 }

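For orientation, after this removal the kernel's declaration reduces to roughly the sketch below. It is assembled from the unchanged context lines above as a reading aid, not the verbatim file; the constructor, the Compute/PrePack bodies, and most private members are elided.

template <typename T1>
class MatMulNBits final : public OpKernel {
 public:
  // Runs the n-bit quantized matmul, consuming the prepacked weight buffer.
  Status Compute(OpKernelContext* context) const override;

  // Packs input_B (and, per the removed comments above, possibly scales/zero points)
  // into the kernel-owned packed_b_ buffer.
  Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
                 /*out*/ bool& is_packed,
                 /*out*/ PrePackedWeights* prepacked_weights) override;

  // Adopts an already-packed buffer shared via the session state.
  Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
                                   /*out*/ bool& used_shared_buffers) override;

 private:
  const size_t K_;
  const size_t N_;
  // ... remaining packing state: packed_b_size_, scales_fp32_, bias_fp32_, has_zp_input_, etc.
};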