Update llama.cpp

This commit is contained in:
Andrei Betlen 2023-05-21 17:47:21 -04:00
parent 8f49ca0287
commit fafe47114c
3 changed files with 186 additions and 44 deletions

View file

@ -127,7 +127,6 @@ class Llama:
self.params = llama_cpp.llama_context_default_params()
self.params.n_ctx = n_ctx
self.params.n_parts = n_parts
self.params.n_gpu_layers = n_gpu_layers
self.params.seed = seed
self.params.f16_kv = f16_kv
@ -149,6 +148,10 @@ class Llama:
self.lora_base = lora_base
self.lora_path = lora_path
### DEPRECATED ###
self.n_parts = n_parts
### DEPRECATED ###
if not os.path.exists(model_path):
raise ValueError(f"Model path does not exist: {model_path}")
@ -1225,7 +1228,6 @@ class Llama:
verbose=self.verbose,
model_path=self.model_path,
n_ctx=self.params.n_ctx,
n_parts=self.params.n_parts,
n_gpu_layers=self.params.n_gpu_layers,
seed=self.params.seed,
f16_kv=self.params.f16_kv,
@ -1239,6 +1241,9 @@ class Llama:
n_threads=self.n_threads,
lora_base=self.lora_base,
lora_path=self.lora_path,
### DEPRECATED ###
n_parts=self.n_parts,
### DEPRECATED ###
)
def __setstate__(self, state):

View file

@ -72,31 +72,61 @@ _lib_base_name = "llama"
# Load the library
_lib = _load_shared_library(_lib_base_name)
# C types
LLAMA_FILE_VERSION = c_int(2)
LLAMA_FILE_MAGIC = b"ggjt"
LLAMA_FILE_MAGIC_UNVERSIONED = b"ggml"
LLAMA_SESSION_MAGIC = b"ggsn"
# Misc
c_float_p = POINTER(c_float)
c_uint8_p = POINTER(c_uint8)
c_size_t_p = POINTER(c_size_t)
# llama.h bindings
# #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
LLAMA_FILE_MAGIC_GGJT = ctypes.c_uint(0x67676A74)
# #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
LLAMA_FILE_MAGIC_GGLA = ctypes.c_uint(0x67676C61)
# #define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
LLAMA_FILE_MAGIC_GGMF = ctypes.c_uint(0x67676D66)
# #define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
LLAMA_FILE_MAGIC_GGML = ctypes.c_uint(0x67676D6C)
# #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
LLAMA_FILE_MAGIC_GGSN = ctypes.c_uint(0x6767736E)
# #define LLAMA_FILE_VERSION 3
LLAMA_FILE_VERSION = c_int(3)
LLAMA_FILE_MAGIC = LLAMA_FILE_MAGIC_GGJT
LLAMA_FILE_MAGIC_UNVERSIONED = LLAMA_FILE_MAGIC_GGML
LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
LLAMA_SESSION_VERSION = c_int(1)
# struct llama_context;
llama_context_p = c_void_p
# typedef int llama_token;
llama_token = c_int
llama_token_p = POINTER(llama_token)
# typedef struct llama_token_data {
# llama_token id; // token id
# float logit; // log-odds of the token
# float p; // probability of the token
# } llama_token_data;
class llama_token_data(Structure):
_fields_ = [
("id", llama_token), # token id
("logit", c_float), # log-odds of the token
("p", c_float), # probability of the token
("id", llama_token),
("logit", c_float),
("p", c_float),
]
llama_token_data_p = POINTER(llama_token_data)
# typedef struct llama_token_data_array {
# llama_token_data * data;
# size_t size;
# bool sorted;
# } llama_token_data_array;
class llama_token_data_array(Structure):
_fields_ = [
("data", llama_token_data_p),
@ -107,54 +137,72 @@ class llama_token_data_array(Structure):
llama_token_data_array_p = POINTER(llama_token_data_array)
# typedef void (*llama_progress_callback)(float progress, void *ctx);
llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
# struct llama_context_params {
# int n_ctx; // text context
# int n_gpu_layers; // number of layers to store in VRAM
# int seed; // RNG seed, -1 for random
# bool f16_kv; // use fp16 for KV cache
# bool logits_all; // the llama_eval() call computes all logits, not just the last one
# bool vocab_only; // only load the vocabulary, no weights
# bool use_mmap; // use mmap if possible
# bool use_mlock; // force system to keep model in RAM
# bool embedding; // embedding mode only
# // called with a progress value between 0 and 1, pass NULL to disable
# llama_progress_callback progress_callback;
# // context pointer passed to the progress callback
# void * progress_callback_user_data;
# };
class llama_context_params(Structure):
_fields_ = [
("n_ctx", c_int), # text context
("n_parts", c_int), # -1 for default
("n_gpu_layers", c_int), # number of layers to store in VRAM
("seed", c_int), # RNG seed, 0 for random
("f16_kv", c_bool), # use fp16 for KV cache
("n_ctx", c_int),
("n_gpu_layers", c_int),
("seed", c_int),
("f16_kv", c_bool),
(
"logits_all",
c_bool,
), # the llama_eval() call computes all logits, not just the last one
("vocab_only", c_bool), # only load the vocabulary, no weights
("use_mmap", c_bool), # use mmap if possible
("use_mlock", c_bool), # force system to keep model in RAM
("embedding", c_bool), # embedding mode only
# called with a progress value between 0 and 1, pass NULL to disable
),
("vocab_only", c_bool),
("use_mmap", c_bool),
("use_mlock", c_bool),
("embedding", c_bool),
("progress_callback", llama_progress_callback),
# context pointer passed to the progress callback
("progress_callback_user_data", c_void_p),
]
llama_context_params_p = POINTER(llama_context_params)
# enum llama_ftype {
# LLAMA_FTYPE_ALL_F32 = 0,
# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
# };
LLAMA_FTYPE_ALL_F32 = c_int(0)
LLAMA_FTYPE_MOSTLY_F16 = c_int(1) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(
4
) # tok_embeddings.weight and output.weight are F16
# LLAMA_FTYPE_MOSTLY_Q4_2 = c_int(5) # except 1d tensors
# LLAMA_FTYPE_MOSTYL_Q4_3 = c_int(6) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9) # except 1d tensors
# Misc
c_float_p = POINTER(c_float)
c_uint8_p = POINTER(c_uint8)
c_size_t_p = POINTER(c_size_t)
# Functions
LLAMA_FTYPE_MOSTLY_F16 = c_int(1)
LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2)
LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3)
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4)
LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7)
LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8)
LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9)
# LLAMA_API struct llama_context_params llama_context_default_params();
def llama_context_default_params() -> llama_context_params:
return _lib.llama_context_default_params()
@ -163,6 +211,7 @@ _lib.llama_context_default_params.argtypes = []
_lib.llama_context_default_params.restype = llama_context_params
# LLAMA_API bool llama_mmap_supported();
def llama_mmap_supported() -> bool:
return _lib.llama_mmap_supported()
@ -171,6 +220,7 @@ _lib.llama_mmap_supported.argtypes = []
_lib.llama_mmap_supported.restype = c_bool
# LLAMA_API bool llama_mlock_supported();
def llama_mlock_supported() -> bool:
return _lib.llama_mlock_supported()
@ -179,9 +229,33 @@ _lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool
# Various functions for loading a ggml llama model.
# Allocate (almost) all memory needed for the model.
# Return NULL on failure
# // TODO: not great API - very likely to change
# // Initialize the llama + ggml backend
# // Call once at the start of the program
# LLAMA_API void llama_init_backend();
def llama_init_backend():
return _lib.llama_init_backend()
_lib.llama_init_backend.argtypes = []
_lib.llama_init_backend.restype = None
# LLAMA_API int64_t llama_time_us();
def llama_time_us() -> int:
return _lib.llama_time_us()
_lib.llama_time_us.argtypes = []
_lib.llama_time_us.restype = ctypes.c_int64
# // Various functions for loading a ggml llama model.
# // Allocate (almost) all memory needed for the model.
# // Return NULL on failure
# LLAMA_API struct llama_context * llama_init_from_file(
# const char * path_model,
# struct llama_context_params params);
def llama_init_from_file(
path_model: bytes, params: llama_context_params
) -> llama_context_p:
@ -193,6 +267,7 @@ _lib.llama_init_from_file.restype = llama_context_p
# Frees all allocated memory
# LLAMA_API void llama_free(struct llama_context * ctx);
def llama_free(ctx: llama_context_p):
return _lib.llama_free(ctx)
@ -204,6 +279,11 @@ _lib.llama_free.restype = None
# TODO: not great API - very likely to change
# Returns 0 on success
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
# LLAMA_API int llama_model_quantize(
# const char * fname_inp,
# const char * fname_out,
# enum llama_ftype ftype,
# int nthread);
def llama_model_quantize(
fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
) -> int:
@ -220,6 +300,11 @@ _lib.llama_model_quantize.restype = c_int
# The model needs to be reloaded before applying a new adapter, otherwise the adapter
# will be applied on top of the previous one
# Returns 0 on success
# LLAMA_API int llama_apply_lora_from_file(
# struct llama_context * ctx,
# const char * path_lora,
# const char * path_base_model,
# int n_threads);
def llama_apply_lora_from_file(
ctx: llama_context_p,
path_lora: c_char_p,
@ -234,6 +319,7 @@ _lib.llama_apply_lora_from_file.restype = c_int
# Returns the number of tokens in the KV cache
# LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
return _lib.llama_get_kv_cache_token_count(ctx)
@ -243,6 +329,7 @@ _lib.llama_get_kv_cache_token_count.restype = c_int
# Sets the current rng seed.
# LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
def llama_set_rng_seed(ctx: llama_context_p, seed: c_int):
return _lib.llama_set_rng_seed(ctx, seed)
@ -253,6 +340,7 @@ _lib.llama_set_rng_seed.restype = None
# Returns the maximum size in bytes of the state (rng, logits, embedding
# and kv_cache) - will often be smaller after compacting tokens
# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
def llama_get_state_size(ctx: llama_context_p) -> int:
return _lib.llama_get_state_size(ctx)
@ -264,6 +352,7 @@ _lib.llama_get_state_size.restype = c_size_t
# Copies the state to the specified destination address.
# Destination needs to have allocated enough memory.
# Returns the number of bytes copied
# LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
def llama_copy_state_data(
ctx: llama_context_p, dst # type: Array[c_uint8]
) -> int:
@ -276,6 +365,7 @@ _lib.llama_copy_state_data.restype = c_size_t
# Set the state reading from the specified address
# Returns the number of bytes read
# LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src);
def llama_set_state_data(
ctx: llama_context_p, src # type: Array[c_uint8]
) -> int:
@ -287,6 +377,7 @@ _lib.llama_set_state_data.restype = c_size_t
# Save/load session file
# LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
def llama_load_session_file(
ctx: llama_context_p,
path_session: bytes,
@ -309,6 +400,7 @@ _lib.llama_load_session_file.argtypes = [
_lib.llama_load_session_file.restype = c_size_t
# LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
def llama_save_session_file(
ctx: llama_context_p,
path_session: bytes,
@ -331,6 +423,12 @@ _lib.llama_save_session_file.restype = c_size_t
# tokens + n_tokens is the provided batch of new tokens to process
# n_past is the number of tokens to use from previous eval calls
# Returns 0 on success
# LLAMA_API int llama_eval(
# struct llama_context * ctx,
# const llama_token * tokens,
# int n_tokens,
# int n_past,
# int n_threads);
def llama_eval(
ctx: llama_context_p,
tokens, # type: Array[llama_token]
@ -350,6 +448,12 @@ _lib.llama_eval.restype = c_int
# Returns the number of tokens on success, no more than n_max_tokens
# Returns a negative number on failure - the number of tokens that would have been returned
# TODO: not sure if correct
# LLAMA_API int llama_tokenize(
# struct llama_context * ctx,
# const char * text,
# llama_token * tokens,
# int n_max_tokens,
# bool add_bos);
def llama_tokenize(
ctx: llama_context_p,
text: bytes,
@ -364,6 +468,7 @@ _lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int,
_lib.llama_tokenize.restype = c_int
# LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
def llama_n_vocab(ctx: llama_context_p) -> int:
return _lib.llama_n_vocab(ctx)
@ -372,6 +477,7 @@ _lib.llama_n_vocab.argtypes = [llama_context_p]
_lib.llama_n_vocab.restype = c_int
# LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
def llama_n_ctx(ctx: llama_context_p) -> int:
return _lib.llama_n_ctx(ctx)
@ -380,6 +486,7 @@ _lib.llama_n_ctx.argtypes = [llama_context_p]
_lib.llama_n_ctx.restype = c_int
# LLAMA_API int llama_n_embd (const struct llama_context * ctx);
def llama_n_embd(ctx: llama_context_p) -> int:
return _lib.llama_n_embd(ctx)
@ -393,6 +500,7 @@ _lib.llama_n_embd.restype = c_int
# Can be mutated in order to change the probabilities of the next token
# Rows: n_tokens
# Cols: n_vocab
# LLAMA_API float * llama_get_logits(struct llama_context * ctx);
def llama_get_logits(
ctx: llama_context_p,
): # type: (...) -> Array[float] # type: ignore
@ -405,6 +513,7 @@ _lib.llama_get_logits.restype = c_float_p
# Get the embeddings for the input
# shape: [n_embd] (1-dimensional)
# LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
def llama_get_embeddings(
ctx: llama_context_p,
): # type: (...) -> Array[float] # type: ignore
@ -416,6 +525,7 @@ _lib.llama_get_embeddings.restype = c_float_p
# Token Id -> String. Uses the vocabulary in the provided context
# LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes:
return _lib.llama_token_to_str(ctx, token)
@ -426,6 +536,7 @@ _lib.llama_token_to_str.restype = c_char_p
# Special tokens
# LLAMA_API llama_token llama_token_bos();
def llama_token_bos() -> int:
return _lib.llama_token_bos()
@ -434,6 +545,7 @@ _lib.llama_token_bos.argtypes = []
_lib.llama_token_bos.restype = llama_token
# LLAMA_API llama_token llama_token_eos();
def llama_token_eos() -> int:
return _lib.llama_token_eos()
@ -442,6 +554,7 @@ _lib.llama_token_eos.argtypes = []
_lib.llama_token_eos.restype = llama_token
# LLAMA_API llama_token llama_token_nl();
def llama_token_nl() -> int:
return _lib.llama_token_nl()
@ -454,6 +567,7 @@ _lib.llama_token_nl.restype = llama_token
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
# LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
def llama_sample_repetition_penalty(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -477,6 +591,7 @@ _lib.llama_sample_repetition_penalty.restype = None
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
def llama_sample_frequency_and_presence_penalties(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -507,6 +622,7 @@ _lib.llama_sample_frequency_and_presence_penalties.restype = None
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
def llama_sample_softmax(
ctx: llama_context_p, candidates # type: _Pointer[llama_token_data]
):
@ -521,6 +637,7 @@ _lib.llama_sample_softmax.restype = None
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
def llama_sample_top_k(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -540,6 +657,7 @@ _lib.llama_sample_top_k.restype = None
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
def llama_sample_top_p(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -559,6 +677,7 @@ _lib.llama_sample_top_p.restype = None
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
def llama_sample_tail_free(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -578,6 +697,7 @@ _lib.llama_sample_tail_free.restype = None
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
def llama_sample_typical(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -596,6 +716,7 @@ _lib.llama_sample_typical.argtypes = [
_lib.llama_sample_typical.restype = None
# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
def llama_sample_temperature(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -618,6 +739,7 @@ _lib.llama_sample_temperature.restype = None
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
def llama_sample_token_mirostat(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -645,6 +767,7 @@ _lib.llama_sample_token_mirostat.restype = llama_token
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
def llama_sample_token_mirostat_v2(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -666,6 +789,7 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
# @details Selects the token with the highest probability.
# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
def llama_sample_token_greedy(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -681,6 +805,7 @@ _lib.llama_sample_token_greedy.restype = llama_token
# @details Randomly selects a token from the candidates based on their probabilities.
# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
def llama_sample_token(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -698,6 +823,7 @@ _lib.llama_sample_token.restype = llama_token
# Performance information
# LLAMA_API void llama_print_timings(struct llama_context * ctx);
def llama_print_timings(ctx: llama_context_p):
_lib.llama_print_timings(ctx)
@ -706,6 +832,7 @@ _lib.llama_print_timings.argtypes = [llama_context_p]
_lib.llama_print_timings.restype = None
# LLAMA_API void llama_reset_timings(struct llama_context * ctx);
def llama_reset_timings(ctx: llama_context_p):
_lib.llama_reset_timings(ctx)
@ -715,9 +842,19 @@ _lib.llama_reset_timings.restype = None
# Print system information
# LLAMA_API const char * llama_print_system_info(void);
def llama_print_system_info() -> bytes:
return _lib.llama_print_system_info()
_lib.llama_print_system_info.argtypes = []
_lib.llama_print_system_info.restype = c_char_p
###################################################################################################
_llama_initialized = False
if not _llama_initialized:
llama_init_backend()
_llama_initialized = True

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit c238b5873a1ea496db03ffcfe124c9d0d83afbc6
Subproject commit 7e4ea5beff567f53be92f75f9089e6f11fa5dabd