@@ -1946,6 +1946,44 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
1946
1946
return true ;
1947
1947
}
1948
1948
1949
+ std::vector<std::pair<std::string,ggml_type> > parse_quant_overrides (const std::string & overrides)
1950
+ {
1951
+ std::vector<std::pair<std::string, ggml_type> > result;
1952
+ for (const auto & item : splitString (overrides, ' ,' )) {
1953
+ std::string::size_type pos = item.find (' =' );
1954
+ if (pos == std::string::npos) {
1955
+ LOG_WARN (" ignoring invalid quant override \" %s\" " , item.c_str ());
1956
+ continue ;
1957
+ }
1958
+ std::string tensor_pattern = item.substr (0 , pos);
1959
+ std::string quant_name = item.substr (pos + 1 );
1960
+
1961
+ ggml_type over_type = GGML_TYPE_COUNT;
1962
+
1963
+ if (quant_name == " f32" ) {
1964
+ over_type = GGML_TYPE_F32;
1965
+ }
1966
+ else {
1967
+ for (size_t i = 0 ; i < SD_TYPE_COUNT; i++) {
1968
+ auto trait = ggml_get_type_traits ((ggml_type)i);
1969
+ if (trait->to_float && trait->type_size && quant_name == trait->type_name ) {
1970
+ over_type = (ggml_type)i;
1971
+ }
1972
+ }
1973
+ }
1974
+
1975
+ if (over_type != GGML_TYPE_COUNT) {
1976
+ // LOG_INFO("adding quant override \"%s\" ==> %s", tensor_pattern.c_str(), ggml_type_name(over_type));
1977
+ result.emplace_back (tensor_pattern, over_type);
1978
+ }
1979
+ else {
1980
+ LOG_WARN (" ignoring invalid quant override \" %s\" " , item.c_str ());
1981
+ }
1982
+
1983
+ }
1984
+ return result;
1985
+ }
1986
+
1949
1987
bool ModelLoader::tensor_should_be_converted (const TensorStorage& tensor_storage, ggml_type type) {
1950
1988
const std::string& name = tensor_storage.name ;
1951
1989
if (type != GGML_TYPE_COUNT) {
@@ -1977,7 +2015,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
1977
2015
return false ;
1978
2016
}
1979
2017
1980
- bool ModelLoader::save_to_gguf_file (const std::string& file_path, ggml_type type) {
2018
+ bool ModelLoader::save_to_gguf_file (const std::string& file_path, ggml_type type, const char * overrides ) {
1981
2019
auto backend = ggml_backend_cpu_init ();
1982
2020
size_t mem_size = 1 * 1024 * 1024 ; // for padding
1983
2021
mem_size += tensor_storages.size () * ggml_tensor_overhead ();
@@ -1987,12 +2025,27 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
1987
2025
1988
2026
gguf_context* gguf_ctx = gguf_init_empty ();
1989
2027
2028
+ if (overrides == nullptr )
2029
+ overrides = " " ;
2030
+ auto quant_overrides = parse_quant_overrides (overrides);
2031
+
1990
2032
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
1991
2033
const std::string& name = tensor_storage.name ;
1992
-
1993
2034
ggml_type tensor_type = tensor_storage.type ;
1994
- if (tensor_should_be_converted (tensor_storage, type)) {
1995
- tensor_type = type;
2035
+ ggml_type change_type = type;
2036
+
2037
+ for (const auto & quant_override : quant_overrides) {
2038
+ std::regex pattern (quant_override.first );
2039
+ if (std::regex_search (name, pattern)) {
2040
+ change_type = quant_override.second ;
2041
+ // LOG_DEBUG("%s quant overriden to %s by \"%s\"",
2042
+ // name.c_str(), ggml_type_name(quant_override.second), quant_override.first.c_str());
2043
+ break ;
2044
+ }
2045
+ }
2046
+
2047
+ if (tensor_should_be_converted (tensor_storage, change_type)) {
2048
+ tensor_type = change_type;
1996
2049
}
1997
2050
1998
2051
ggml_tensor* tensor = ggml_new_tensor (ggml_ctx, tensor_type, tensor_storage.n_dims , tensor_storage.ne );
@@ -2051,7 +2104,8 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
2051
2104
return mem_size;
2052
2105
}
2053
2106
2054
- bool convert (const char * input_path, const char * vae_path, const char * output_path, sd_type_t output_type) {
2107
+ bool convert (const char * input_path, const char * vae_path, const char * output_path, sd_type_t output_type,
2108
+ const char * output_tensor_type) {
2055
2109
ModelLoader model_loader;
2056
2110
2057
2111
if (!model_loader.init_from_file (input_path)) {
@@ -2065,6 +2119,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
2065
2119
return false ;
2066
2120
}
2067
2121
}
2068
- bool success = model_loader.save_to_gguf_file (output_path, (ggml_type)output_type);
2122
+ bool success = model_loader.save_to_gguf_file (output_path, (ggml_type)output_type, output_tensor_type );
2069
2123
return success;
2070
2124
}
0 commit comments