Skip to content

Commit e128cfa

Browse files
committed
WIP: feat: support for per-tensor quant types
Allows changing weight quants by regex, with a syntax similar to the llama.cpp tensor overrides. This is useful, for instance, to reduce UNet memory requirements without changing the more sensitive CLIP weights. I only did the conversion part for now, but I intend to adapt the inference code afterwards. Note that if we consider a plain quant name as matching any tensor, this could reuse the --type command line argument instead of demanding another one. The same applies to the type parameters for conversion and inference, which in principle could simply be changed to a const char*.
1 parent 10c6501 commit e128cfa

File tree

4 files changed

+71
-9
lines changed

4 files changed

+71
-9
lines changed

examples/cli/main.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ struct SDParams {
8484
std::string stacked_id_embeddings_path;
8585
std::string input_id_images_path;
8686
sd_type_t wtype = SD_TYPE_COUNT;
87+
std::string tensor_wtype;
8788
std::string lora_model_dir;
8889
std::string output_path = "output.png";
8990
std::string input_path;
@@ -204,6 +205,7 @@ void print_usage(int argc, const char* argv[]) {
204205
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
205206
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
206207
printf(" If not specified, the default is the type of the weight file\n");
208+
printf(" --tensor_type [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
207209
printf(" --lora-model-dir [DIR] lora model directory\n");
208210
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
209211
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
@@ -381,6 +383,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
381383
valid_types.c_str());
382384
exit(1);
383385
}
386+
} else if (arg == "--tensor_type") {
387+
if (++i >= argc) {
388+
invalid_arg = true;
389+
break;
390+
}
391+
params.tensor_wtype = argv[i];
384392
} else if (arg == "--lora-model-dir") {
385393
if (++i >= argc) {
386394
invalid_arg = true;
@@ -800,7 +808,7 @@ int main(int argc, const char* argv[]) {
800808
}
801809

802810
if (params.mode == CONVERT) {
803-
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
811+
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_wtype.c_str());
804812
if (!success) {
805813
fprintf(stderr,
806814
"convert '%s'/'%s' to '%s' failed\n",

model.cpp

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1946,6 +1946,44 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
19461946
return true;
19471947
}
19481948

1949+
std::vector<std::pair<std::string,ggml_type> > parse_quant_overrides (const std::string & overrides)
1950+
{
1951+
std::vector<std::pair<std::string, ggml_type> > result;
1952+
for (const auto & item : splitString(overrides, ',')) {
1953+
std::string::size_type pos = item.find('=');
1954+
if (pos == std::string::npos) {
1955+
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
1956+
continue;
1957+
}
1958+
std::string tensor_pattern = item.substr(0, pos);
1959+
std::string quant_name = item.substr(pos + 1);
1960+
1961+
ggml_type over_type = GGML_TYPE_COUNT;
1962+
1963+
if (quant_name == "f32") {
1964+
over_type = GGML_TYPE_F32;
1965+
}
1966+
else {
1967+
for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
1968+
auto trait = ggml_get_type_traits((ggml_type)i);
1969+
if (trait->to_float && trait->type_size && quant_name == trait->type_name) {
1970+
over_type = (ggml_type)i;
1971+
}
1972+
}
1973+
}
1974+
1975+
if (over_type != GGML_TYPE_COUNT) {
1976+
//LOG_INFO("adding quant override \"%s\" ==> %s", tensor_pattern.c_str(), ggml_type_name(over_type));
1977+
result.emplace_back(tensor_pattern, over_type);
1978+
}
1979+
else {
1980+
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
1981+
}
1982+
1983+
}
1984+
return result;
1985+
}
1986+
19491987
bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
19501988
const std::string& name = tensor_storage.name;
19511989
if (type != GGML_TYPE_COUNT) {
@@ -1977,7 +2015,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
19772015
return false;
19782016
}
19792017

1980-
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
2018+
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const char * overrides) {
19812019
auto backend = ggml_backend_cpu_init();
19822020
size_t mem_size = 1 * 1024 * 1024; // for padding
19832021
mem_size += tensor_storages.size() * ggml_tensor_overhead();
@@ -1987,12 +2025,27 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
19872025

19882026
gguf_context* gguf_ctx = gguf_init_empty();
19892027

2028+
if (overrides == nullptr)
2029+
overrides = "";
2030+
auto quant_overrides = parse_quant_overrides(overrides);
2031+
19902032
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
19912033
const std::string& name = tensor_storage.name;
1992-
19932034
ggml_type tensor_type = tensor_storage.type;
1994-
if (tensor_should_be_converted(tensor_storage, type)) {
1995-
tensor_type = type;
2035+
ggml_type change_type = type;
2036+
2037+
for (const auto & quant_override : quant_overrides) {
2038+
std::regex pattern(quant_override.first);
2039+
if (std::regex_search(name, pattern)) {
2040+
change_type = quant_override.second;
2041+
//LOG_DEBUG("%s quant overridden to %s by \"%s\"",
2042+
// name.c_str(), ggml_type_name(quant_override.second), quant_override.first.c_str());
2043+
break;
2044+
}
2045+
}
2046+
2047+
if (tensor_should_be_converted(tensor_storage, change_type)) {
2048+
tensor_type = change_type;
19962049
}
19972050

19982051
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
@@ -2051,7 +2104,8 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
20512104
return mem_size;
20522105
}
20532106

2054-
bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
2107+
bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type,
2108+
const char * output_tensor_type) {
20552109
ModelLoader model_loader;
20562110

20572111
if (!model_loader.init_from_file(input_path)) {
@@ -2065,6 +2119,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
20652119
return false;
20662120
}
20672121
}
2068-
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
2122+
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, output_tensor_type);
20692123
return success;
20702124
}

model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class ModelLoader {
221221
ggml_backend_t backend,
222222
std::set<std::string> ignore_tensors = {});
223223

224-
bool save_to_gguf_file(const std::string& file_path, ggml_type type);
224+
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const char * overrides);
225225
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
226226
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
227227
~ModelLoader() = default;

stable-diffusion.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
228228

229229
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
230230

231-
SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
231+
SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type, const char * output_tensor_type);
232232

233233
SD_API uint8_t* preprocess_canny(uint8_t* img,
234234
int width,

0 commit comments

Comments
 (0)