From 6d8bc090f9b17dd10b1359a8df09a9b25f8f9036 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 22 Dec 2023 14:52:20 -0500 Subject: [PATCH] fix: inccorect bindings for kv override. Based on #1011 --- llama_cpp/llama_cpp.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 3732b58..64b567b 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -9,6 +9,7 @@ from ctypes import ( c_int32, c_uint8, c_uint32, + c_int64, c_size_t, c_float, c_double, @@ -16,6 +17,7 @@ from ctypes import ( POINTER, _Pointer, # type: ignore Structure, + Union as CtypesUnion, Array, ) import pathlib @@ -317,12 +319,9 @@ class llama_batch(Structure): # LLAMA_KV_OVERRIDE_FLOAT, # LLAMA_KV_OVERRIDE_BOOL, # }; -class llama_model_kv_override_type(Structure): - _fields_ = [ - ("LLAMA_KV_OVERRIDE_INT", c_int), - ("LLAMA_KV_OVERRIDE_FLOAT", c_int), - ("LLAMA_KV_OVERRIDE_BOOL", c_int), - ] +LLAMA_KV_OVERRIDE_INT = 0 +LLAMA_KV_OVERRIDE_FLOAT = 1 +LLAMA_KV_OVERRIDE_BOOL = 2 # struct llama_model_kv_override { # char key[128]; @@ -333,13 +332,18 @@ class llama_model_kv_override_type(Structure): # bool bool_value; # }; # }; +class llama_model_kv_override_value(CtypesUnion): + _fields_ = [ + ("int_value", c_int64), + ("float_value", c_double), + ("bool_value", c_bool), + ] + class llama_model_kv_override(Structure): _fields_ = [ ("key", ctypes.c_char * 128), - ("tag", llama_model_kv_override_type), - ("int_value", ctypes.c_int64), - ("float_value", c_double), - ("bool_value", c_bool), + ("tag", c_int), + ("value", llama_model_kv_override_value), ] # struct llama_model_params {