diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 97c6565..2ffc4c5 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -172,7 +172,9 @@ _lib.llama_free.restype = None
 # TODO: not great API - very likely to change
 # Returns 0 on success
 # nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
-def llama_model_quantize(fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int) -> c_int:
+def llama_model_quantize(
+    fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
+) -> c_int:
     return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
 
 
@@ -187,7 +189,10 @@ _lib.llama_model_quantize.restype = c_int
 # will be applied on top of the previous one
 # Returns 0 on success
 def llama_apply_lora_from_file(
-    ctx: llama_context_p, path_lora: ctypes.c_char_p, path_base_model: ctypes.c_char_p, n_threads: c_int
+    ctx: llama_context_p,
+    path_lora: ctypes.c_char_p,
+    path_base_model: ctypes.c_char_p,
+    n_threads: c_int,
 ) -> c_int:
     return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
 
@@ -235,6 +240,36 @@ _lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t,
 _lib.llama_set_kv_cache.restype = None
 
 
+# Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
+    return _lib.llama_get_state_size(ctx)
+
+
+_lib.llama_get_state_size.argtypes = [llama_context_p]
+_lib.llama_get_state_size.restype = c_size_t
+
+
+# Copies the state to the specified destination address.
+# Destination needs to have allocated enough memory.
+# Returns the number of bytes copied
+def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
+    return _lib.llama_copy_state_data(ctx, dest)
+
+
+_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
+_lib.llama_copy_state_data.restype = c_size_t
+
+
+# Set the state reading from the specified address
+# Returns the number of bytes read
+def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
+    return _lib.llama_set_state_data(ctx, src)
+
+
+_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
+_lib.llama_set_state_data.restype = c_size_t
+
+
 # Run the llama inference to obtain the logits and probabilities for the next token.
 # tokens + n_tokens is the provided batch of new tokens to process
 # n_past is the number of tokens to use from previous eval calls
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index 50cb666..0e018fe 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit 50cb666b8a2e35a49b08c0f6bc81138c8f6f2ac1
+Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd
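
Not part of the patch above: a minimal usage sketch of how the three new state
bindings could fit together to snapshot and later restore a context. The helper
names snapshot_state/restore_state are hypothetical, and ctx is assumed to be a
valid llama_context_p created elsewhere (e.g. via llama_init_from_file).

import ctypes

from llama_cpp.llama_cpp import (
    llama_copy_state_data,
    llama_get_state_size,
    llama_set_state_data,
)


def snapshot_state(ctx) -> bytes:
    # Hypothetical helper: ask the library how large the serialized state
    # (rng, logits, embedding and kv_cache) is, allocate a matching buffer,
    # and copy the state out of the context.
    n_bytes = llama_get_state_size(ctx)
    buf = (ctypes.c_uint8 * n_bytes)()
    n_copied = llama_copy_state_data(ctx, buf)
    return bytes(buf[:n_copied])


def restore_state(ctx, blob: bytes) -> None:
    # Hypothetical helper: feed a previously captured snapshot back into
    # the context, rewinding it to the point where the snapshot was taken.
    buf = (ctypes.c_uint8 * len(blob)).from_buffer_copy(blob)
    llama_set_state_data(ctx, buf)

Calling snapshot_state before a speculative eval and restore_state afterwards
would, under these assumptions, let a caller rewind the context without
re-evaluating the prompt.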