fix: inccorect bindings for kv override. Based on #1011

This commit is contained in:
Andrei Betlen 2023-12-22 14:52:20 -05:00
parent f4be84c122
commit 6d8bc090f9

View file

@ -9,6 +9,7 @@ from ctypes import (
c_int32, c_int32,
c_uint8, c_uint8,
c_uint32, c_uint32,
c_int64,
c_size_t, c_size_t,
c_float, c_float,
c_double, c_double,
@ -16,6 +17,7 @@ from ctypes import (
POINTER, POINTER,
_Pointer, # type: ignore _Pointer, # type: ignore
Structure, Structure,
Union as CtypesUnion,
Array, Array,
) )
import pathlib import pathlib
@ -317,12 +319,9 @@ class llama_batch(Structure):
# LLAMA_KV_OVERRIDE_FLOAT, # LLAMA_KV_OVERRIDE_FLOAT,
# LLAMA_KV_OVERRIDE_BOOL, # LLAMA_KV_OVERRIDE_BOOL,
# }; # };
class llama_model_kv_override_type(Structure): LLAMA_KV_OVERRIDE_INT = 0
_fields_ = [ LLAMA_KV_OVERRIDE_FLOAT = 1
("LLAMA_KV_OVERRIDE_INT", c_int), LLAMA_KV_OVERRIDE_BOOL = 2
("LLAMA_KV_OVERRIDE_FLOAT", c_int),
("LLAMA_KV_OVERRIDE_BOOL", c_int),
]
# struct llama_model_kv_override { # struct llama_model_kv_override {
# char key[128]; # char key[128];
@ -333,13 +332,18 @@ class llama_model_kv_override_type(Structure):
# bool bool_value; # bool bool_value;
# }; # };
# }; # };
class llama_model_kv_override_value(CtypesUnion):
_fields_ = [
("int_value", c_int64),
("float_value", c_double),
("bool_value", c_bool),
]
class llama_model_kv_override(Structure): class llama_model_kv_override(Structure):
_fields_ = [ _fields_ = [
("key", ctypes.c_char * 128), ("key", ctypes.c_char * 128),
("tag", llama_model_kv_override_type), ("tag", c_int),
("int_value", ctypes.c_int64), ("value", llama_model_kv_override_value),
("float_value", c_double),
("bool_value", c_bool),
] ]
# struct llama_model_params { # struct llama_model_params {