Update llama.cpp

This commit is contained in:
Andrei Betlen 2023-04-22 19:50:28 -04:00
parent 643b73e155
commit e99caedbbd
2 changed files with 38 additions and 3 deletions

View file

@ -172,7 +172,9 @@ _lib.llama_free.restype = None
# TODO: not great API - very likely to change
# Returns 0 on success
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
def llama_model_quantize(fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int) -> c_int:
def llama_model_quantize(
fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
) -> c_int:
return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
@ -187,7 +189,10 @@ _lib.llama_model_quantize.restype = c_int
# will be applied on top of the previous one
# Returns 0 on success
def llama_apply_lora_from_file(
ctx: llama_context_p, path_lora: ctypes.c_char_p, path_base_model: ctypes.c_char_p, n_threads: c_int
ctx: llama_context_p,
path_lora: ctypes.c_char_p,
path_base_model: ctypes.c_char_p,
n_threads: c_int,
) -> c_int:
return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
@ -235,6 +240,36 @@ _lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t,
_lib.llama_set_kv_cache.restype = None
# Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
return _lib.llama_get_state_size(ctx)
_lib.llama_get_state_size.argtypes = [llama_context_p]
_lib.llama_get_state_size.restype = c_size_t
# Copies the state to the specified destination address.
# Destination needs to have allocated enough memory.
# Returns the number of bytes copied
def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
return _lib.llama_copy_state_data(ctx, dest)
_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
_lib.llama_copy_state_data.restype = c_size_t
# Set the state reading from the specified address
# Returns the number of bytes read
def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
return _lib.llama_set_state_data(ctx, src)
_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
_lib.llama_set_state_data.restype = c_size_t
# Run the llama inference to obtain the logits and probabilities for the next token.
# tokens + n_tokens is the provided batch of new tokens to process
# n_past is the number of tokens to use from previous eval calls

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 50cb666b8a2e35a49b08c0f6bc81138c8f6f2ac1
Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd