@@ -49,6 +49,7 @@ class TensorInfo:
49
49
class GGUFValue :
50
50
value : Any
51
51
type : GGUFValueType
52
+ sub_type : GGUFValueType | None = None
52
53
53
54
54
55
class WriterState (Enum ):
@@ -238,7 +239,7 @@ def write_kv_data_to_file(self) -> None:
238
239
239
240
for key , val in kv_data .items ():
240
241
kv_bytes += self ._pack_val (key , GGUFValueType .STRING , add_vtype = False )
241
- kv_bytes += self ._pack_val (val .value , val .type , add_vtype = True )
242
+ kv_bytes += self ._pack_val (val .value , val .type , add_vtype = True , sub_type = val . sub_type )
242
243
243
244
fout .write (kv_bytes )
244
245
@@ -268,11 +269,11 @@ def write_ti_data_to_file(self) -> None:
268
269
fout .flush ()
269
270
self .state = WriterState .TI_DATA
270
271
271
- def add_key_value (self , key : str , val : Any , vtype : GGUFValueType ) -> None :
272
+ def add_key_value (self , key : str , val : Any , vtype : GGUFValueType , sub_type : GGUFValueType | None = None ) -> None :
272
273
if any (key in kv_data for kv_data in self .kv_data ):
273
274
raise ValueError (f'Duplicated key name { key !r} ' )
274
275
275
- self .kv_data [0 ][key ] = GGUFValue (value = val , type = vtype )
276
+ self .kv_data [0 ][key ] = GGUFValue (value = val , type = vtype , sub_type = sub_type )
276
277
277
278
def add_uint8 (self , key : str , val : int ) -> None :
278
279
self .add_key_value (key ,val , GGUFValueType .UINT8 )
@@ -1022,7 +1023,7 @@ def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
1022
1023
pack_prefix = '<' if self .endianess == GGUFEndian .LITTLE else '>'
1023
1024
return struct .pack (f'{ pack_prefix } { fmt } ' , value )
1024
1025
1025
- def _pack_val (self , val : Any , vtype : GGUFValueType , add_vtype : bool ) -> bytes :
1026
+ def _pack_val (self , val : Any , vtype : GGUFValueType , add_vtype : bool , sub_type : GGUFValueType | None = None ) -> bytes :
1026
1027
kv_data = bytearray ()
1027
1028
1028
1029
if add_vtype :
@@ -1043,7 +1044,9 @@ def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
1043
1044
if len (val ) == 0 :
1044
1045
raise ValueError ("Invalid GGUF metadata array. Empty array" )
1045
1046
1046
- if isinstance (val , bytes ):
1047
+ if sub_type is not None :
1048
+ ltype = sub_type
1049
+ elif isinstance (val , bytes ):
1047
1050
ltype = GGUFValueType .UINT8
1048
1051
else :
1049
1052
ltype = GGUFValueType .get_type (val [0 ])
0 commit comments