diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 0c0d48f..54424cb 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -30,6 +30,7 @@ import numpy.typing as npt from .utils import suppress_stdout_stderr + class BaseLlamaCache(ABC): """Base cache class for a llama.cpp model.""" @@ -215,30 +216,37 @@ class Llama: self, model_path: str, *, - # NOTE: These parameters are likely to change in the future. - seed: int = llama_cpp.LLAMA_DEFAULT_SEED, - n_ctx: int = 512, - n_batch: int = 512, + # Model Params n_gpu_layers: int = 0, main_gpu: int = 0, tensor_split: Optional[List[float]] = None, - rope_freq_base: float = 10000.0, - rope_freq_scale: float = 1.0, - low_vram: bool = False, - mul_mat_q: bool = True, - f16_kv: bool = True, - logits_all: bool = False, vocab_only: bool = False, use_mmap: bool = True, use_mlock: bool = False, - embedding: bool = False, + # Context Params + seed: int = llama_cpp.LLAMA_DEFAULT_SEED, + n_ctx: int = 512, + n_batch: int = 512, n_threads: Optional[int] = None, + n_threads_batch: Optional[int] = None, + rope_freq_base: float = 10000.0, + rope_freq_scale: float = 1.0, + mul_mat_q: bool = True, + f16_kv: bool = True, + logits_all: bool = False, + embedding: bool = False, + # Sampling Params last_n_tokens_size: int = 64, + # LoRA Params lora_base: Optional[str] = None, + lora_scale: float = 1.0, lora_path: Optional[str] = None, + # Backend Params numa: bool = False, + # Misc verbose: bool = True, - **kwargs # type: ignore + # Extra Params + **kwargs, # type: ignore ): """Load a llama.cpp model from `model_path`. @@ -277,52 +285,64 @@ class Llama: self.verbose = verbose + self.numa = numa if not Llama.__backend_initialized: if self.verbose: - llama_cpp.llama_backend_init(numa) + llama_cpp.llama_backend_init(self.numa) else: with suppress_stdout_stderr(): - llama_cpp.llama_backend_init(numa) + llama_cpp.llama_backend_init(self.numa) Llama.__backend_initialized = True self.model_path = model_path - self.params = llama_cpp.llama_context_default_params() - self.params.seed = seed - self.params.n_ctx = n_ctx - self.params.n_gpu_layers = 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers # 0x7FFFFFFF is INT32 max, will be auto set to all layers - self.params.main_gpu = main_gpu - self.params.rope_freq_base = rope_freq_base - self.params.rope_freq_scale = rope_freq_scale - self.params.low_vram = low_vram - self.params.mul_mat_q = mul_mat_q - self.params.f16_kv = f16_kv - self.params.logits_all = logits_all - self.params.vocab_only = vocab_only - self.params.use_mmap = use_mmap if lora_path is None else False - self.params.use_mlock = use_mlock - self.params.embedding = embedding - + # Model Params + self.model_params = llama_cpp.llama_model_default_params() + self.model_params.n_gpu_layers = ( + 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers + ) # 0x7FFFFFFF is INT32 max, will be auto set to all layers + self.model_params.main_gpu = main_gpu self.tensor_split = tensor_split self._p_tensor_split = None - if self.tensor_split is not None: # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES self._c_tensor_split = FloatArray( - *tensor_split + *tensor_split # type: ignore ) # keep a reference to the array so it is not gc'd - self.params.tensor_split = self._c_tensor_split + self.model_params.tensor_split = self._c_tensor_split + self.model_params.vocab_only = vocab_only + self.model_params.use_mmap = use_mmap if lora_path is None else False + self.model_params.use_mlock = use_mlock 
+ self.n_batch = min(n_ctx, n_batch) # ??? + self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) + self.n_threads_batch = n_threads_batch or max( + multiprocessing.cpu_count() // 2, 1 + ) + # Context Params + self.context_params = llama_cpp.llama_context_default_params() + self.context_params.seed = seed + self.context_params.n_ctx = n_ctx + self.context_params.n_batch = self.n_batch + self.context_params.n_threads = self.n_threads + self.context_params.n_threads_batch = self.n_threads_batch + self.context_params.rope_freq_base = rope_freq_base + self.context_params.rope_freq_scale = rope_freq_scale + self.context_params.mul_mat_q = mul_mat_q + self.context_params.f16_kv = f16_kv + self.context_params.logits_all = logits_all + self.context_params.embedding = embedding + + # Sampling Params self.last_n_tokens_size = last_n_tokens_size - self.n_batch = min(n_ctx, n_batch) + self.cache: Optional[BaseLlamaCache] = None - self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) - self.lora_base = lora_base + self.lora_scale = lora_scale self.lora_path = lora_path if not os.path.exists(model_path): @@ -330,21 +350,23 @@ class Llama: if verbose: self.model = llama_cpp.llama_load_model_from_file( - self.model_path.encode("utf-8"), self.params + self.model_path.encode("utf-8"), self.model_params ) else: with suppress_stdout_stderr(): self.model = llama_cpp.llama_load_model_from_file( - self.model_path.encode("utf-8"), self.params + self.model_path.encode("utf-8"), self.model_params ) assert self.model is not None if verbose: - self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params) + self.ctx = llama_cpp.llama_new_context_with_model( + self.model, self.context_params + ) else: with suppress_stdout_stderr(): self.ctx = llama_cpp.llama_new_context_with_model( - self.model, self.params + self.model, self.context_params ) assert self.ctx is not None @@ -353,6 +375,7 @@ class Llama: if llama_cpp.llama_model_apply_lora_from_file( self.model, self.lora_path.encode("utf-8"), + self.lora_scale, self.lora_base.encode("utf-8") if self.lora_base is not None else llama_cpp.c_char_p(0), @@ -409,7 +432,7 @@ class Llama: def eval_logits(self) -> Deque[List[float]]: return deque( self.scores[: self.n_tokens, :].tolist(), - maxlen=self._n_ctx if self.params.logits_all else 1, + maxlen=self._n_ctx if self.model_params.logits_all else 1, ) def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: @@ -427,7 +450,7 @@ class Llama: assert self.model is not None n_ctx = self._n_ctx tokens = (llama_cpp.llama_token * n_ctx)() - n_tokens = llama_cpp.llama_tokenize_with_model( + n_tokens = llama_cpp.llama_tokenize( self.model, text, len(text), @@ -438,7 +461,7 @@ class Llama: if n_tokens < 0: n_tokens = abs(n_tokens) tokens = (llama_cpp.llama_token * n_tokens)() - n_tokens = llama_cpp.llama_tokenize_with_model( + n_tokens = llama_cpp.llama_tokenize( self.model, text, len(text), @@ -466,14 +489,16 @@ class Llama: size = 32 buffer = (ctypes.c_char * size)() for token in tokens: - n = llama_cpp.llama_token_to_piece_with_model( + n = llama_cpp.llama_token_to_piece( self.model, llama_cpp.llama_token(token), buffer, size ) assert n <= size output += bytes(buffer[:n]) # NOTE: Llama1 models automatically added a space at the start of the prompt # this line removes a leading space if the first token is a beginning of sentence token - return output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output + return ( + output[1:] if len(tokens) > 0 and tokens[0] == 
self.token_bos() else output + ) def set_cache(self, cache: Optional[BaseLlamaCache]): """Set the cache. @@ -504,17 +529,16 @@ class Llama: tokens=(llama_cpp.llama_token * len(batch))(*batch), n_tokens=n_tokens, n_past=n_past, - n_threads=self.n_threads, ) if return_code != 0: raise RuntimeError(f"llama_eval returned {return_code}") # Save tokens self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch # Save logits - rows = n_tokens if self.params.logits_all else 1 + rows = n_tokens if self.context_params.logits_all else 1 cols = self._n_vocab offset = ( - 0 if self.params.logits_all else n_tokens - 1 + 0 if self.context_params.logits_all else n_tokens - 1 ) # NOTE: Only save the last token logits if logits_all is False self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape( -1 @@ -545,11 +569,7 @@ class Llama: n_vocab = self._n_vocab n_ctx = self._n_ctx top_k = n_vocab if top_k <= 0 else top_k - last_n_tokens_size = ( - n_ctx - if last_n_tokens_size < 0 - else last_n_tokens_size - ) + last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size logits: npt.NDArray[np.single] = self._scores[-1, :] if logits_processor is not None: @@ -610,7 +630,7 @@ class Llama: mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore m=mirostat_m, ) - elif mirostat_mode== 2: + elif mirostat_mode == 2: mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau) llama_cpp.llama_sample_temperature( ctx=self.ctx, @@ -802,7 +822,7 @@ class Llama: def create_embedding( self, input: Union[str, List[str]], model: Optional[str] = None - ) -> Embedding: + ) -> CreateEmbeddingResponse: """Embed a string. Args: @@ -814,7 +834,7 @@ class Llama: assert self.ctx is not None model_name: str = model if model is not None else self.model_path - if self.params.embedding == False: + if self.model_params.embedding == False: raise RuntimeError( "Llama model must be created with embedding=True to call this method" ) @@ -900,7 +920,11 @@ class Llama: created: int = int(time.time()) completion_tokens: List[int] = [] # Add blank space to start of prompt to match OG llama tokenizer - prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8")) if prompt != "" else [self.token_bos()] + prompt_tokens: List[int] = ( + self.tokenize(prompt.encode("utf-8")) + if prompt != "" + else [self.token_bos()] + ) text: bytes = b"" returned_tokens: int = 0 stop = ( @@ -932,7 +956,7 @@ class Llama: else: stop_sequences = [] - if logprobs is not None and self.params.logits_all is False: + if logprobs is not None and self.model_params.logits_all is False: raise ValueError( "logprobs is not supported for models created with logits_all=False" ) @@ -1025,7 +1049,9 @@ class Llama: for token in remaining_tokens: token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token - if token_end_position > (remaining_length - first_stop_position): + if token_end_position > ( + remaining_length - first_stop_position + ): break token_str = self.detokenize([token]).decode( "utf-8", errors="ignore" @@ -1082,7 +1108,7 @@ class Llama: for i in range(1, len(remaining_tokens) + 1): try: bs = self.detokenize(remaining_tokens[:i]) - ts = bs.decode('utf-8') + ts = bs.decode("utf-8") decode_success = True break except UnicodeError: @@ -1093,7 +1119,9 @@ class Llama: # all remaining tokens cannot be decoded to a UTF-8 character break token_end_position += len(bs) - if token_end_position > (remaining_length - first_stop_position): + if token_end_position > ( + remaining_length - first_stop_position + ): break 
remaining_tokens = remaining_tokens[i:] returned_tokens += i @@ -1398,7 +1426,7 @@ class Llama: model=model, stopping_criteria=stopping_criteria, logits_processor=logits_processor, - grammar=grammar + grammar=grammar, ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks @@ -1618,47 +1646,68 @@ class Llama: def __getstate__(self): return dict( - verbose=self.verbose, model_path=self.model_path, - n_ctx=self.params.n_ctx, - n_gpu_layers=self.params.n_gpu_layers, - seed=self.params.seed, - f16_kv=self.params.f16_kv, - logits_all=self.params.logits_all, - vocab_only=self.params.vocab_only, - use_mmap=self.params.use_mmap, - use_mlock=self.params.use_mlock, - embedding=self.params.embedding, - low_vram=self.params.low_vram, - last_n_tokens_size=self.last_n_tokens_size, - n_batch=self.n_batch, - n_threads=self.n_threads, - lora_base=self.lora_base, - lora_path=self.lora_path, + # Model Params + n_gpu_layers=self.model_params.n_gpu_layers, + main_gpu=self.model_params.main_gpu, tensor_split=self.tensor_split, - mul_mat_q=self.params.mul_mat_q, + vocab_only=self.model_params.vocab_only, + use_mmap=self.model_params.use_mmap, + use_mlock=self.model_params.use_mlock, + # Context Params + seed=self.context_params.seed, + n_ctx=self.context_params.n_ctx, + n_batch=self.n_batch, + n_threads=self.context_params.n_threads, + n_threads_batch=self.context_params.n_threads_batch, + rope_freq_base=self.context_params.rope_freq_base, + rope_freq_scale=self.context_params.rope_freq_scale, + mul_mat_q=self.context_params.mul_mat_q, + f16_kv=self.context_params.f16_kv, + logits_all=self.context_params.logits_all, + embedding=self.context_params.embedding, + # Sampling Params + last_n_tokens_size=self.last_n_tokens_size, + # LoRA Params + lora_base=self.lora_base, + lora_scale=self.lora_scale, + lora_path=self.lora_path, + # Backend Params + numa=self.numa, + # Misc + verbose=self.verbose, ) def __setstate__(self, state): self.__init__( model_path=state["model_path"], - n_ctx=state["n_ctx"], + # Model Params n_gpu_layers=state["n_gpu_layers"], - seed=state["seed"], - f16_kv=state["f16_kv"], - logits_all=state["logits_all"], + main_gpu=state["main_gpu"], + tensor_split=state["tensor_split"], vocab_only=state["vocab_only"], use_mmap=state["use_mmap"], use_mlock=state["use_mlock"], - embedding=state["embedding"], - low_vram=state["low_vram"], - n_threads=state["n_threads"], + # Context Params + seed=state["seed"], + n_ctx=state["n_ctx"], n_batch=state["n_batch"], + n_threads=state["n_threads"], + n_threads_batch=state["n_threads_batch"], + rope_freq_base=state["rope_freq_base"], + rope_freq_scale=state["rope_freq_scale"], + mul_mat_q=state["mul_mat_q"], + f16_kv=state["f16_kv"], + logits_all=state["logits_all"], + embedding=state["embedding"], + # Sampling Params last_n_tokens_size=state["last_n_tokens_size"], + # LoRA Params lora_base=state["lora_base"], lora_path=state["lora_path"], - tensor_split=state["tensor_split"], - mul_mat_q=state["mul_mat_q"], + # Backend Params + numa=state["numa"], + # Misc verbose=state["verbose"], ) @@ -1711,13 +1760,13 @@ class Llama: def n_embd(self) -> int: """Return the embedding size.""" - assert self.ctx is not None - return llama_cpp.llama_n_embd(self.ctx) + assert self.model is not None + return llama_cpp.llama_n_embd(self.model) def n_vocab(self) -> int: """Return the vocabulary size.""" - assert self.ctx is not None - return llama_cpp.llama_n_vocab(self.ctx) + assert self.model is not None + return llama_cpp.llama_n_vocab(self.model) def tokenizer(self) -> 
"LlamaTokenizer": """Return the tokenizer for this model.""" diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 53298df..4734aec 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -2,20 +2,21 @@ import sys import os import ctypes from ctypes import ( - c_double, - c_int, - c_float, - c_char_p, - c_int32, - c_uint32, - c_void_p, c_bool, + c_char_p, + c_int, + c_int8, + c_int32, + c_uint8, + c_uint32, + c_size_t, + c_float, + c_double, + c_void_p, POINTER, _Pointer, # type: ignore Structure, Array, - c_uint8, - c_size_t, ) import pathlib from typing import List, Union @@ -93,6 +94,9 @@ LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else 1 # define LLAMA_DEFAULT_SEED 0xFFFFFFFF LLAMA_DEFAULT_SEED = 0xFFFFFFFF +# define LLAMA_MAX_RNG_STATE (64*1024) +LLAMA_MAX_RNG_STATE = 64 * 1024 + # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' LLAMA_FILE_MAGIC_GGSN = 0x6767736E @@ -109,18 +113,14 @@ llama_model_p = c_void_p llama_context_p = c_void_p -# typedef int llama_token; -llama_token = c_int +# typedef int32_t llama_pos; +llama_pos = c_int32 +# typedef int32_t llama_token; +llama_token = c_int32 llama_token_p = POINTER(llama_token) +# typedef int32_t llama_seq_id; +llama_seq_id = c_int32 -# enum llama_log_level { -# LLAMA_LOG_LEVEL_ERROR = 2, -# LLAMA_LOG_LEVEL_WARN = 3, -# LLAMA_LOG_LEVEL_INFO = 4 -# }; -LLAMA_LOG_LEVEL_ERROR = 2 -LLAMA_LOG_LEVEL_WARN = 3 -LLAMA_LOG_LEVEL_INFO = 4 # enum llama_vocab_type { # LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece @@ -147,27 +147,29 @@ LLAMA_TOKEN_TYPE_USER_DEFINED = 4 LLAMA_TOKEN_TYPE_UNUSED = 5 LLAMA_TOKEN_TYPE_BYTE = 6 + +# // model file types # enum llama_ftype { # LLAMA_FTYPE_ALL_F32 = 0, -# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 -# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed -# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed -# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors -# +# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 +# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed +# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed +# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors +# 
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors + # LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file # }; LLAMA_FTYPE_ALL_F32 = 0 @@ -224,19 +226,55 @@ llama_token_data_array_p = POINTER(llama_token_data_array) # typedef void (*llama_progress_callback)(float progress, void *ctx); llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) -# struct llama_context_params { -# uint32_t seed; // RNG seed, -1 for random -# int32_t n_ctx; // text context -# int32_t n_batch; // prompt processing batch size -# int32_t n_gpu_layers; // number of layers to store in VRAM -# int32_t main_gpu; // the GPU that is used for scratch and small tensors +# // Input data for llama_decode +# // A llama_batch object can contain input about one or many sequences +# // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens +# // +# // - token : the token ids of the input (used when embd is NULL) +# // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) +# // - pos : the positions of the respective token in the sequence +# // - seq_id : the sequence to which the respective token belongs +# // - logits : if zero, the logits for the respective token will not be output +# // +# typedef struct llama_batch { +# int32_t n_tokens; + +# llama_token * token; +# float * embd; +# llama_pos * pos; +# llama_seq_id * seq_id; +# int8_t * logits; + + +# // NOTE: helpers for smooth API transition - can be deprecated in the future +# // for future-proof code, use the above fields instead and ignore everything below +# // +# // pos[i] = all_pos_0 + i*all_pos_1 +# // +# llama_pos all_pos_0; // used if pos == NULL +# llama_pos all_pos_1; // used if pos == NULL +# llama_seq_id all_seq_id; // used if seq_id == NULL +# } llama_batch; +class llama_batch(Structure): + _fields_ = [ + ("n_tokens", c_int32), + ("token", POINTER(llama_token)), + ("embd", c_float_p), + ("pos", POINTER(llama_pos)), + ("seq_id", POINTER(llama_seq_id)), + ("logits", POINTER(c_int8)), + ("all_pos_0", llama_pos), + ("all_pos_1", llama_pos), + ("all_seq_id", llama_seq_id), + ] + + +# struct llama_model_params { +# int32_t n_gpu_layers; // number of layers to store in VRAM +# int32_t main_gpu; // the GPU that is used for scratch and small tensors # const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) -# // ref: https://github.com/ggerganov/llama.cpp/pull/2054 -# float rope_freq_base; // RoPE base frequency -# float rope_freq_scale; // RoPE frequency scaling factor - # // called with a progress value between 0 and 1, pass NULL to disable # llama_progress_callback progress_callback; # // context pointer passed to the progress callback @@ -244,41 +282,57 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # // Keep the booleans together to avoid misalignment during copy-by-value. 
-# bool low_vram; // if true, reduce VRAM usage at the cost of performance -# bool mul_mat_q; // if true, use experimental mul_mat_q kernels -# bool f16_kv; // use fp16 for KV cache -# bool logits_all; // the llama_eval() call computes all logits, not just the last one # bool vocab_only; // only load the vocabulary, no weights # bool use_mmap; // use mmap if possible # bool use_mlock; // force system to keep model in RAM +# }; +class llama_model_params(Structure): + _fields_ = [ + ("n_gpu_layers", c_int32), + ("main_gpu", c_int32), + ("tensor_split", c_float_p), + ("progress_callback", llama_progress_callback), + ("progress_callback_user_data", c_void_p), + ("vocab_only", c_bool), + ("use_mmap", c_bool), + ("use_mlock", c_bool), + ] + + +# struct llama_context_params { +# uint32_t seed; // RNG seed, -1 for random +# uint32_t n_ctx; // text context +# uint32_t n_batch; // prompt processing batch size +# uint32_t n_threads; // number of threads to use for generation +# uint32_t n_threads_batch; // number of threads to use for batch processing + +# // ref: https://github.com/ggerganov/llama.cpp/pull/2054 +# float rope_freq_base; // RoPE base frequency +# float rope_freq_scale; // RoPE frequency scaling factor + + +# // Keep the booleans together to avoid misalignment during copy-by-value. +# bool mul_mat_q; // if true, use experimental mul_mat_q kernels +# bool f16_kv; // use fp16 for KV cache +# bool logits_all; // the llama_eval() call computes all logits, not just the last one # bool embedding; // embedding mode only # }; class llama_context_params(Structure): _fields_ = [ ("seed", c_uint32), - ("n_ctx", c_int32), - ("n_batch", c_int32), - ("n_gpu_layers", c_int32), - ("main_gpu", c_int32), - ("tensor_split", c_float_p), + ("n_ctx", c_uint32), + ("n_batch", c_uint32), + ("n_threads", c_uint32), + ("n_threads_batch", c_uint32), ("rope_freq_base", c_float), ("rope_freq_scale", c_float), - ("progress_callback", llama_progress_callback), - ("progress_callback_user_data", c_void_p), - ("low_vram", c_bool), ("mul_mat_q", c_bool), ("f16_kv", c_bool), ("logits_all", c_bool), - ("vocab_only", c_bool), - ("use_mmap", c_bool), - ("use_mlock", c_bool), ("embedding", c_bool), ] -llama_context_params_p = POINTER(llama_context_params) - - # // Signature for logging events # // Note that text includes the new line character at the end for most events. 
# // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it @@ -385,6 +439,16 @@ class llama_timings(Structure): ] +# // Helpers for getting default parameters +# LLAMA_API struct llama_model_params llama_model_default_params(void); +def llama_model_default_params() -> llama_model_params: + return _lib.llama_model_default_params() + + +_lib.llama_model_default_params.argtypes = [] +_lib.llama_model_default_params.restype = llama_model_params + + # LLAMA_API struct llama_context_params llama_context_default_params(void); def llama_context_default_params() -> llama_context_params: return _lib.llama_context_default_params() @@ -429,12 +493,12 @@ _lib.llama_backend_free.restype = None # const char * path_model, # struct llama_context_params params); def llama_load_model_from_file( - path_model: bytes, params: llama_context_params + path_model: bytes, params: llama_model_params ) -> llama_model_p: return _lib.llama_load_model_from_file(path_model, params) -_lib.llama_load_model_from_file.argtypes = [c_char_p, llama_context_params] +_lib.llama_load_model_from_file.argtypes = [c_char_p, llama_model_params] _lib.llama_load_model_from_file.restype = llama_model_p @@ -506,13 +570,13 @@ _lib.llama_mlock_supported.argtypes = [] _lib.llama_mlock_supported.restype = c_bool -# LLAMA_API int llama_n_vocab (const struct llama_context * ctx); -def llama_n_vocab(ctx: llama_context_p) -> int: - return _lib.llama_n_vocab(ctx) +# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); +def llama_get_model(ctx: llama_context_p) -> llama_model_p: + return _lib.llama_get_model(ctx) -_lib.llama_n_vocab.argtypes = [llama_context_p] -_lib.llama_n_vocab.restype = c_int +_lib.llama_get_model.argtypes = [llama_context_p] +_lib.llama_get_model.restype = llama_model_p # LLAMA_API int llama_n_ctx (const struct llama_context * ctx); @@ -524,72 +588,47 @@ _lib.llama_n_ctx.argtypes = [llama_context_p] _lib.llama_n_ctx.restype = c_int -# LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx); -def llama_n_ctx_train(ctx: llama_context_p) -> int: - return _lib.llama_n_ctx_train(ctx) +# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); +def llama_vocab_type(model: llama_model_p) -> int: + return _lib.llama_vocab_type(model) -_lib.llama_n_ctx_train.argtypes = [llama_context_p] -_lib.llama_n_ctx_train.restype = c_int - - -# LLAMA_API int llama_n_embd (const struct llama_context * ctx); -def llama_n_embd(ctx: llama_context_p) -> int: - return _lib.llama_n_embd(ctx) - - -_lib.llama_n_embd.argtypes = [llama_context_p] -_lib.llama_n_embd.restype = c_int - - -# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx); -def llama_vocab_type(ctx: llama_context_p) -> int: - return _lib.llama_vocab_type(ctx) - - -_lib.llama_vocab_type.argtypes = [llama_context_p] +_lib.llama_vocab_type.argtypes = [llama_model_p] _lib.llama_vocab_type.restype = c_int -# LLAMA_API int llama_model_n_vocab (const struct llama_model * model); -def llama_model_n_vocab(model: llama_model_p) -> int: - return _lib.llama_model_n_vocab(model) +# LLAMA_API int llama_n_vocab (const struct llama_model * model); +def llama_n_vocab(model: llama_model_p) -> int: + return _lib.llama_n_vocab(model) -_lib.llama_model_n_vocab.argtypes = [llama_model_p] -_lib.llama_model_n_vocab.restype = c_int +_lib.llama_n_vocab.argtypes = [llama_model_p] +_lib.llama_n_vocab.restype = c_int -# LLAMA_API int llama_model_n_ctx (const struct llama_model * 
model); -def llama_model_n_ctx(model: llama_model_p) -> int: - return _lib.llama_model_n_ctx(model) +# LLAMA_API int llama_n_ctx_train(const struct llama_model * model); +def llama_n_ctx_train(model: llama_model_p) -> int: + return _lib.llama_n_ctx_train(model) -_lib.llama_model_n_ctx.argtypes = [llama_model_p] -_lib.llama_model_n_ctx.restype = c_int +_lib.llama_n_ctx_train.argtypes = [llama_model_p] +_lib.llama_n_ctx_train.restype = c_int -# LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model); -def llama_model_n_ctx_train(model: llama_model_p) -> int: - return _lib.llama_model_n_ctx_train(model) +# LLAMA_API int llama_n_embd (const struct llama_model * model); +def llama_n_embd(model: llama_model_p) -> int: + return _lib.llama_n_embd(model) -_lib.llama_model_n_ctx_train.argtypes = [llama_model_p] -_lib.llama_model_n_ctx_train.restype = c_int - - -# LLAMA_API int llama_model_n_embd (const struct llama_model * model); -def llama_model_n_embd(model: llama_model_p) -> int: - return _lib.llama_model_n_embd(model) - - -_lib.llama_model_n_embd.argtypes = [llama_model_p] -_lib.llama_model_n_embd.restype = c_int +_lib.llama_n_embd.argtypes = [llama_model_p] +_lib.llama_n_embd.restype = c_int # // Get a string describing the model type # LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); -def llama_model_desc(model: llama_model_p, buf: bytes, buf_size: Union[c_size_t, int]) -> int: +def llama_model_desc( + model: llama_model_p, buf: bytes, buf_size: Union[c_size_t, int] +) -> int: return _lib.llama_model_desc(model, buf, buf_size) @@ -617,6 +656,18 @@ _lib.llama_model_n_params.argtypes = [llama_model_p] _lib.llama_model_n_params.restype = ctypes.c_uint64 +# // Get a llama model tensor +# LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); +def llama_get_model_tensor( + model: llama_model_p, name: Union[c_char_p, bytes] +) -> c_void_p: + return _lib.llama_get_model_tensor(model, name) + + +_lib.llama_get_model_tensor.argtypes = [llama_model_p, c_char_p] +_lib.llama_get_model_tensor.restype = c_void_p + + # // Returns 0 on success # LLAMA_API int llama_model_quantize( # const char * fname_inp, @@ -638,57 +689,76 @@ _lib.llama_model_quantize.argtypes = [ _lib.llama_model_quantize.restype = c_int -# Apply a LoRA adapter to a loaded model -# path_base_model is the path to a higher quality model to use as a base for -# the layers modified by the adapter. Can be NULL to use the current loaded model. -# The model needs to be reloaded before applying a new adapter, otherwise the adapter -# will be applied on top of the previous one -# Returns 0 on success -# LLAMA_API int llama_apply_lora_from_file( +# // Apply a LoRA adapter to a loaded model +# // path_base_model is the path to a higher quality model to use as a base for +# // the layers modified by the adapter. Can be NULL to use the current loaded model. 
+# // The model needs to be reloaded before applying a new adapter, otherwise the adapter +# // will be applied on top of the previous one +# // Returns 0 on success +# LLAMA_API DEPRECATED(int llama_apply_lora_from_file( # struct llama_context * ctx, # const char * path_lora, +# float scale, # const char * path_base_model, -# int n_threads); +# int n_threads), +# "use llama_model_apply_lora_from_file instead"); def llama_apply_lora_from_file( ctx: llama_context_p, path_lora: Union[c_char_p, bytes], + scale: Union[c_float, float], path_base_model: Union[c_char_p, bytes], n_threads: Union[c_int, int], ) -> int: - return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads) + return _lib.llama_apply_lora_from_file( + ctx, path_lora, scale, path_base_model, n_threads + ) -_lib.llama_apply_lora_from_file.argtypes = [llama_context_p, c_char_p, c_char_p, c_int] +_lib.llama_apply_lora_from_file.argtypes = [ + llama_context_p, + c_char_p, + c_float, + c_char_p, + c_int, +] _lib.llama_apply_lora_from_file.restype = c_int # LLAMA_API int llama_model_apply_lora_from_file( # const struct llama_model * model, -# const char * path_lora, -# const char * path_base_model, -# int n_threads); +# const char * path_lora, +# float scale, +# const char * path_base_model, +# int n_threads); def llama_model_apply_lora_from_file( model: llama_model_p, path_lora: Union[c_char_p, bytes], + scale: Union[c_float, float], path_base_model: Union[c_char_p, bytes], n_threads: Union[c_int, int], ) -> int: return _lib.llama_model_apply_lora_from_file( - model, path_lora, path_base_model, n_threads + model, path_lora, scale, path_base_model, n_threads ) _lib.llama_model_apply_lora_from_file.argtypes = [ llama_model_p, c_char_p, + c_float, c_char_p, c_int, ] _lib.llama_model_apply_lora_from_file.restype = c_int +# // +# // KV cache +# // -# Returns the number of tokens in the KV cache -# LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); + +# // Returns the number of tokens in the KV cache +# LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), +# "avoid using this, it will be removed in the future, instead - count the tokens in user code"); def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int: return _lib.llama_get_kv_cache_token_count(ctx) @@ -697,14 +767,118 @@ _lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p] _lib.llama_get_kv_cache_token_count.restype = c_int -# Sets the current rng seed. 
-# LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed); -def llama_set_rng_seed(ctx: llama_context_p, seed: c_uint32): - return _lib.llama_set_rng_seed(ctx, seed) +# // Remove all tokens data of cells in [c0, c1) +# LLAMA_API void llama_kv_cache_tokens_rm( +# struct llama_context * ctx, +# int32_t c0, +# int32_t c1); +def llama_kv_cache_tokens_rm( + ctx: llama_context_p, c0: Union[c_int32, int], c1: Union[c_int32, int] +): + return _lib.llama_kv_cache_tokens_rm(ctx, c0, c1) -_lib.llama_set_rng_seed.argtypes = [llama_context_p, c_int] -_lib.llama_set_rng_seed.restype = None +_lib.llama_kv_cache_tokens_rm.argtypes = [llama_context_p, c_int32, c_int32] +_lib.llama_kv_cache_tokens_rm.restype = None + + +# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) +# LLAMA_API void llama_kv_cache_seq_rm( +# struct llama_context * ctx, +# llama_seq_id seq_id, +# llama_pos p0, +# llama_pos p1); +def llama_kv_cache_seq_rm( + ctx: llama_context_p, + seq_id: llama_seq_id, + p0: Union[llama_pos, int], + p1: Union[llama_pos, int], +): + return _lib.llama_kv_cache_seq_rm(ctx, seq_id, p0, p1) + + +_lib.llama_kv_cache_seq_rm.argtypes = [ + llama_context_p, + llama_seq_id, + llama_pos, + llama_pos, +] +_lib.llama_kv_cache_seq_rm.restype = None + + +# // Copy all tokens that belong to the specified sequence to another sequence +# // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence +# LLAMA_API void llama_kv_cache_seq_cp( +# struct llama_context * ctx, +# llama_seq_id seq_id_src, +# llama_seq_id seq_id_dst, +# llama_pos p0, +# llama_pos p1); +def llama_kv_cache_seq_cp( + ctx: llama_context_p, + seq_id_src: llama_seq_id, + seq_id_dst: llama_seq_id, + p0: Union[llama_pos, int], + p1: Union[llama_pos, int], +): + return _lib.llama_kv_cache_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1) + + +_lib.llama_kv_cache_seq_cp.argtypes = [ + llama_context_p, + llama_seq_id, + llama_seq_id, + llama_pos, + llama_pos, +] +_lib.llama_kv_cache_seq_cp.restype = None + + +# // Removes all tokens that do not belong to the specified sequence +# LLAMA_API void llama_kv_cache_seq_keep( +# struct llama_context * ctx, +# llama_seq_id seq_id); +def llama_kv_cache_seq_keep( + ctx: llama_context_p, + seq_id: llama_seq_id, +): + return _lib.llama_kv_cache_seq_keep(ctx, seq_id) + + +_lib.llama_kv_cache_seq_keep.argtypes = [llama_context_p, llama_seq_id] +_lib.llama_kv_cache_seq_keep.restype = None + + +# // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) +# // If the KV cache is RoPEd, the KV data is updated accordingly +# LLAMA_API void llama_kv_cache_seq_shift( +# struct llama_context * ctx, +# llama_seq_id seq_id, +# llama_pos p0, +# llama_pos p1, +# llama_pos delta); +def llama_kv_cache_seq_shift( + ctx: llama_context_p, + seq_id: llama_seq_id, + p0: Union[llama_pos, int], + p1: Union[llama_pos, int], + delta: Union[llama_pos, int], +): + return _lib.llama_kv_cache_seq_shift(ctx, seq_id, p0, p1, delta) + + +_lib.llama_kv_cache_seq_shift.argtypes = [ + llama_context_p, + llama_seq_id, + llama_pos, + llama_pos, + llama_pos, +] +_lib.llama_kv_cache_seq_shift.restype = None + +# // +# // State / sessions +# // # Returns the maximum size in bytes of the state (rng, logits, embedding @@ -721,7 +895,9 @@ _lib.llama_get_state_size.restype = c_size_t # Copies the state to the specified destination address. # Destination needs to have allocated enough memory. 
# Returns the number of bytes copied -# LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst); +# LLAMA_API size_t llama_copy_state_data( +# struct llama_context * ctx, +# uint8_t * dst); def llama_copy_state_data( ctx: llama_context_p, dst # type: Array[c_uint8] ) -> int: @@ -734,7 +910,9 @@ _lib.llama_copy_state_data.restype = c_size_t # Set the state reading from the specified address # Returns the number of bytes read -# LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src); +# LLAMA_API size_t llama_set_state_data( +# struct llama_context * ctx, +# uint8_t * src); def llama_set_state_data( ctx: llama_context_p, src # type: Array[c_uint8] ) -> int: @@ -746,7 +924,12 @@ _lib.llama_set_state_data.restype = c_size_t # Save/load session file -# LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); +# LLAMA_API bool llama_load_session_file( +# struct llama_context * ctx, +# const char * path_session, +# llama_token * tokens_out, +# size_t n_token_capacity, +# size_t * n_token_count_out); def llama_load_session_file( ctx: llama_context_p, path_session: bytes, @@ -769,7 +952,11 @@ _lib.llama_load_session_file.argtypes = [ _lib.llama_load_session_file.restype = c_size_t -# LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count); +# LLAMA_API bool llama_save_session_file( +# struct llama_context * ctx, +# const char * path_session, +# const llama_token * tokens, +# size_t n_token_count); def llama_save_session_file( ctx: llama_context_p, path_session: bytes, @@ -787,70 +974,148 @@ _lib.llama_save_session_file.argtypes = [ ] _lib.llama_save_session_file.restype = c_size_t +# // +# // Decoding +# // -# Run the llama inference to obtain the logits and probabilities for the next token. -# tokens + n_tokens is the provided batch of new tokens to process -# n_past is the number of tokens to use from previous eval calls -# Returns 0 on success -# LLAMA_API int llama_eval( + +# // Run the llama inference to obtain the logits and probabilities for the next token(s). +# // tokens + n_tokens is the provided batch of new tokens to process +# // n_past is the number of tokens to use from previous eval calls +# // Returns 0 on success +# // DEPRECATED: use llama_decode() instead +# LLAMA_API DEPRECATED(int llama_eval( # struct llama_context * ctx, -# const llama_token * tokens, -# int n_tokens, -# int n_past, -# int n_threads); +# llama_token * tokens, +# int32_t n_tokens, +# int n_past), +# "use llama_decode() instead"); def llama_eval( ctx: llama_context_p, tokens, # type: Array[llama_token] n_tokens: Union[c_int, int], n_past: Union[c_int, int], - n_threads: Union[c_int, int], ) -> int: - return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads) + return _lib.llama_eval(ctx, tokens, n_tokens, n_past) -_lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_int] +_lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int, c_int] _lib.llama_eval.restype = c_int # // Same as llama_eval, but use float matrix input directly. 
-# LLAMA_API int llama_eval_embd( +# // DEPRECATED: use llama_decode() instead +# LLAMA_API DEPRECATED(int llama_eval_embd( # struct llama_context * ctx, -# const float * embd, -# int n_tokens, -# int n_past, -# int n_threads); +# float * embd, +# int32_t n_tokens, +# int n_past), +# "use llama_decode() instead"); def llama_eval_embd( ctx: llama_context_p, embd, # type: Array[c_float] n_tokens: Union[c_int, int], n_past: Union[c_int, int], - n_threads: Union[c_int, int], ) -> int: - return _lib.llama_eval_embd(ctx, embd, n_tokens, n_past, n_threads) + return _lib.llama_eval_embd(ctx, embd, n_tokens, n_past) -_lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int, c_int, c_int] +_lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int, c_int] _lib.llama_eval_embd.restype = c_int -# // Export a static computation graph for context of 511 and batch size of 1 -# // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these -# // parameters here to keep things simple -# // IMPORTANT: do not use for anything else other than debugging and testing! -# LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname); -def llama_eval_export(ctx: llama_context_p, fname: bytes) -> int: - return _lib.llama_eval_export(ctx, fname) +# // Return batch for single sequence of tokens starting at pos_0 +# // +# // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it +# // +# LLAMA_API struct llama_batch llama_batch_get_one( +# llama_token * tokens, +# int32_t n_tokens, +# llama_pos pos_0, +# llama_seq_id seq_id); +def llama_batch_get_one( + tokens, # type: Array[llama_token] + n_tokens: Union[c_int, int], + pos_0: Union[llama_pos, int], + seq_id: llama_seq_id, +) -> llama_batch: + return _lib.llama_batch_get_one(tokens, n_tokens, pos_0, seq_id) -_lib.llama_eval_export.argtypes = [llama_context_p, c_char_p] -_lib.llama_eval_export.restype = c_int +_lib.llama_batch_get_one.argtypes = [ + llama_token_p, + c_int, + llama_pos, + llama_seq_id, +] +_lib.llama_batch_get_one.restype = llama_batch -# Token logits obtained from the last call to llama_eval() -# The logits for the last token are stored in the last row -# Can be mutated in order to change the probabilities of the next token -# Rows: n_tokens -# Cols: n_vocab +# // Allocates a batch of tokens on the heap +# // The batch has to be freed with llama_batch_free() +# // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) +# // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token +# // The rest of the llama_batch members are allocated with size n_tokens +# // All members are left uninitialized +# LLAMA_API struct llama_batch llama_batch_init( +# int32_t n_tokens, +# int32_t embd); +def llama_batch_init( + n_tokens: Union[c_int, int], embd: Union[c_int, int] +) -> llama_batch: + return _lib.llama_batch_init(n_tokens, embd) + + +_lib.llama_batch_init.argtypes = [c_int, c_int] +_lib.llama_batch_init.restype = llama_batch + + +# // Frees a batch of tokens allocated with llama_batch_init() +# LLAMA_API void llama_batch_free(struct llama_batch batch); +def llama_batch_free(batch: llama_batch): + return _lib.llama_batch_free(batch) + + +_lib.llama_batch_free.argtypes = [llama_batch] +_lib.llama_batch_free.restype = None + + +# // Positive return values does not mean a fatal error, but rather a warning. 
+# // 0 - success +# // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) +# // < 0 - error +# LLAMA_API int llama_decode( +# struct llama_context * ctx, +# struct llama_batch batch); +def llama_decode(ctx: llama_context_p, batch: llama_batch) -> int: + return _lib.llama_decode(ctx, batch) + + +_lib.llama_decode.argtypes = [llama_context_p, llama_batch] +_lib.llama_decode.restype = c_int + + +# // Set the number of threads used for decoding +# // n_threads is the number of threads used for generation (single token) +# // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) +# LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); +def llama_set_n_threads( + ctx: llama_context_p, + n_threads: Union[c_uint32, int], + n_threads_batch: Union[c_uint32, int], +): + return _lib.llama_set_n_threads(ctx, n_threads, n_threads_batch) + + +_lib.llama_set_n_threads.argtypes = [llama_context_p, c_uint32, c_uint32] +_lib.llama_set_n_threads.restype = None + + +# // Token logits obtained from the last call to llama_eval() +# // The logits for the last token are stored in the last row +# // Logits for which llama_batch.logits[i] == 0 are undefined +# // Rows: n_tokens provided with llama_batch +# // Cols: n_vocab # LLAMA_API float * llama_get_logits(struct llama_context * ctx); def llama_get_logits( ctx: llama_context_p, @@ -862,6 +1127,19 @@ _lib.llama_get_logits.argtypes = [llama_context_p] _lib.llama_get_logits.restype = c_float_p +# // Logits for the ith token. Equivalent to: +# // llama_get_logits(ctx) + i*n_vocab +# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); +def llama_get_logits_ith( + ctx: llama_context_p, i: Union[c_int32, int] +): # type: (...) -> Array[float] # type: ignore + return _lib.llama_get_logits_ith(ctx, i) + + +_lib.llama_get_logits_ith.argtypes = [llama_context_p, c_int32] +_lib.llama_get_logits_ith.restype = c_float_p + + # Get the embeddings for the input # shape: [n_embd] (1-dimensional) # LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); @@ -911,7 +1189,7 @@ _lib.llama_token_get_type.restype = ctypes.c_int # LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence -def llama_token_bos(ctx: llama_context_p) -> llama_token: +def llama_token_bos(ctx: llama_context_p) -> int: return _lib.llama_token_bos(ctx) @@ -920,7 +1198,7 @@ _lib.llama_token_bos.restype = llama_token # LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence -def llama_token_eos(ctx: llama_context_p) -> llama_token: +def llama_token_eos(ctx: llama_context_p) -> int: return _lib.llama_token_eos(ctx) @@ -929,7 +1207,7 @@ _lib.llama_token_eos.restype = llama_token # LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line -def llama_token_nl(ctx: llama_context_p) -> llama_token: +def llama_token_nl(ctx: llama_context_p) -> int: return _lib.llama_token_nl(ctx) @@ -942,41 +1220,18 @@ _lib.llama_token_nl.restype = llama_token # // -# Convert the provided text into tokens. -# The tokens pointer must be large enough to hold the resulting tokens. -# Returns the number of tokens on success, no more than n_max_tokens -# Returns a negative number on failure - the number of tokens that would have been returned -# TODO: not sure if correct +# // Convert the provided text into tokens. 
+# // The tokens pointer must be large enough to hold the resulting tokens. +# // Returns the number of tokens on success, no more than n_max_tokens +# // Returns a negative number on failure - the number of tokens that would have been returned # LLAMA_API int llama_tokenize( -# struct llama_context * ctx, -# const char * text, -# int text_len, -# llama_token * tokens, -# int n_max_tokens, -# bool add_bos); -def llama_tokenize( - ctx: llama_context_p, - text: bytes, - text_len: Union[c_int, int], - tokens, # type: Array[llama_token] - n_max_tokens: Union[c_int, int], - add_bos: Union[c_bool, int], -) -> int: - return _lib.llama_tokenize(ctx, text, text_len, tokens, n_max_tokens, add_bos) - - -_lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, c_int, llama_token_p, c_int, c_bool] -_lib.llama_tokenize.restype = c_int - - -# LLAMA_API int llama_tokenize_with_model( # const struct llama_model * model, # const char * text, # int text_len, # llama_token * tokens, # int n_max_tokens, # bool add_bos); -def llama_tokenize_with_model( +def llama_tokenize( model: llama_model_p, text: bytes, text_len: Union[c_int, int], @@ -984,10 +1239,10 @@ def llama_tokenize_with_model( n_max_tokens: Union[c_int, int], add_bos: Union[c_bool, bool], ) -> int: - return _lib.llama_tokenize_with_model(model, text, text_len, tokens, n_max_tokens, add_bos) + return _lib.llama_tokenize(model, text, text_len, tokens, n_max_tokens, add_bos) -_lib.llama_tokenize_with_model.argtypes = [ +_lib.llama_tokenize.argtypes = [ llama_model_p, c_char_p, c_int, @@ -995,7 +1250,7 @@ _lib.llama_tokenize_with_model.argtypes = [ c_int, c_bool, ] -_lib.llama_tokenize_with_model.restype = c_int +_lib.llama_tokenize.restype = c_int # // Token Id -> Piece. @@ -1003,39 +1258,23 @@ _lib.llama_tokenize_with_model.restype = c_int # // Does not write null terminator to the buffer. # // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. # LLAMA_API int llama_token_to_piece( -# const struct llama_context * ctx, -# llama_token token, -# char * buf, -# int length); +# const struct llama_model * model, +# llama_token token, +# char * buf, +# int length); def llama_token_to_piece( - ctx: llama_context_p, token: llama_token, buf: bytes, length: Union[c_int, int] + model: llama_model_p, + token: llama_token, + buf: Union[c_char_p, bytes], + length: Union[c_int, int], ) -> int: - return _lib.llama_token_to_piece(ctx, token, buf, length) + return _lib.llama_token_to_piece(model, token, buf, length) -_lib.llama_token_to_piece.argtypes = [llama_context_p, llama_token, c_char_p, c_int] +_lib.llama_token_to_piece.argtypes = [llama_model_p, llama_token, c_char_p, c_int] _lib.llama_token_to_piece.restype = c_int -# LLAMA_API int llama_token_to_piece_with_model( -# const struct llama_model * model, -# llama_token token, -# char * buf, -# int length); -def llama_token_to_piece_with_model( - model: llama_model_p, token: llama_token, buf: bytes, length: Union[c_int, int] -) -> int: - return _lib.llama_token_to_piece_with_model(model, token, buf, length) - - -_lib.llama_token_to_piece_with_model.argtypes = [ - llama_model_p, - llama_token, - c_char_p, - c_int, -] -_lib.llama_token_to_piece_with_model.restype = c_int - # // # // Grammar # // @@ -1083,8 +1322,23 @@ _lib.llama_grammar_copy.restype = llama_grammar_p # // +# // Sets the current rng seed. 
+# LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); +def llama_set_rng_seed(ctx: llama_context_p, seed: Union[c_uint32, int]): + return _lib.llama_set_rng_seed(ctx, seed) + + +_lib.llama_set_rng_seed.argtypes = [llama_context_p, c_uint32] +_lib.llama_set_rng_seed.restype = None + + # @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. -# LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); +# LLAMA_API void llama_sample_repetition_penalty( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# const llama_token * last_tokens, +# size_t last_tokens_size, +# float penalty); def llama_sample_repetition_penalty( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1108,7 +1362,13 @@ _lib.llama_sample_repetition_penalty.restype = None # @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. -# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); +# LLAMA_API void llama_sample_frequency_and_presence_penalties( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# const llama_token * last_tokens, +# size_t last_tokens_size, +# float alpha_frequency, +# float alpha_presence); def llama_sample_frequency_and_presence_penalties( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1168,7 +1428,9 @@ _lib.llama_sample_classifier_free_guidance.restype = None # @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
-# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); +# LLAMA_API void llama_sample_softmax( +# struct llama_context * ctx, +# llama_token_data_array * candidates); def llama_sample_softmax( ctx: llama_context_p, candidates # type: _Pointer[llama_token_data] ): @@ -1183,7 +1445,11 @@ _lib.llama_sample_softmax.restype = None # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 -# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep); +# LLAMA_API void llama_sample_top_k( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# int k, +# size_t min_keep); def llama_sample_top_k( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1203,7 +1469,11 @@ _lib.llama_sample_top_k.restype = None # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 -# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); +# LLAMA_API void llama_sample_top_p( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# float p, +# size_t min_keep); def llama_sample_top_p( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1223,7 +1493,11 @@ _lib.llama_sample_top_p.restype = None # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. -# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep); +# LLAMA_API void llama_sample_tail_free( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# float z, +# size_t min_keep); def llama_sample_tail_free( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1243,7 +1517,11 @@ _lib.llama_sample_tail_free.restype = None # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. 
-# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); +# LLAMA_API void llama_sample_typical( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# float p, +# size_t min_keep); def llama_sample_typical( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1262,7 +1540,31 @@ _lib.llama_sample_typical.argtypes = [ _lib.llama_sample_typical.restype = None -# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); +# LLAMA_API void llama_sample_temp( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# float temp); +def llama_sample_temp( + ctx: llama_context_p, + candidates, # type: _Pointer[llama_token_data_array] + temp: Union[c_float, float], +): + return _lib.llama_sample_temp(ctx, candidates, temp) + + +_lib.llama_sample_temp.argtypes = [ + llama_context_p, + llama_token_data_array_p, + c_float, +] +_lib.llama_sample_temp.restype = None + + +# LLAMA_API DEPRECATED(void llama_sample_temperature( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# float temp), +# "use llama_sample_temp instead"); def llama_sample_temperature( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1302,7 +1604,13 @@ _lib.llama_sample_grammar.restype = None # @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. # @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. # @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu); +# LLAMA_API llama_token llama_sample_token_mirostat( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# float tau, +# float eta, +# int m, +# float * mu); def llama_sample_token_mirostat( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1330,7 +1638,12 @@ _lib.llama_sample_token_mirostat.restype = llama_token # @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. # @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. # @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
-# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); +# LLAMA_API llama_token llama_sample_token_mirostat_v2( +# struct llama_context * ctx, +# llama_token_data_array * candidates, +# float tau, +# float eta, +# float * mu); def llama_sample_token_mirostat_v2( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1352,7 +1665,9 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token # @details Selects the token with the highest probability. -# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); +# LLAMA_API llama_token llama_sample_token_greedy( +# struct llama_context * ctx, +# llama_token_data_array * candidates); def llama_sample_token_greedy( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1368,7 +1683,9 @@ _lib.llama_sample_token_greedy.restype = llama_token # @details Randomly selects a token from the candidates based on their probabilities. -# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); +# LLAMA_API llama_token llama_sample_token( +# struct llama_context * ctx, +# llama_token_data_array * candidates); def llama_sample_token( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] @@ -1384,7 +1701,10 @@ _lib.llama_sample_token.restype = llama_token # /// @details Accepts the sampled token into the grammar -# LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); +# LLAMA_API void llama_grammar_accept_token( +# struct llama_context * ctx, +# struct llama_grammar * grammar, +# llama_token token); def llama_grammar_accept_token( ctx: llama_context_p, grammar: llama_grammar_p, @@ -1399,16 +1719,18 @@ _lib.llama_grammar_accept_token.argtypes = [ llama_token, ] _lib.llama_grammar_accept_token.restype = None + + # // # // Beam search # // - # struct llama_beam_view { # const llama_token * tokens; + # size_t n_tokens; -# float p; // Cumulative beam probability (renormalized relative to all beams) -# bool eob; // Callback should set this to true when a beam is at end-of-beam. +# float p; // Cumulative beam probability (renormalized relative to all beams) +# bool eob; // Callback should set this to true when a beam is at end-of-beam. # }; class llama_beam_view(ctypes.Structure): _fields_ = [ @@ -1427,7 +1749,7 @@ class llama_beam_view(ctypes.Structure): # struct llama_beam_view * beam_views; # size_t n_beams; // Number of elements in beam_views[]. # size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. -# bool last_call; // True iff this is the last callback invocation. +# bool last_call; // True iff this is the last callback invocation. # }; class llama_beams_state(ctypes.Structure): _fields_ = [ @@ -1453,7 +1775,13 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_s # /// @param n_past Number of tokens already evaluated. # /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. # /// @param n_threads Number of threads as passed to llama_eval(). 
-# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads); +# LLAMA_API void llama_beam_search( +# struct llama_context * ctx, +# llama_beam_search_callback_fn_t callback, +# void * callback_data, +# size_t n_beams, +# int n_past, +# int n_predict); def llama_beam_search( ctx: llama_context_p, callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore @@ -1461,12 +1789,21 @@ def llama_beam_search( n_beams: Union[c_size_t, int], n_past: Union[c_int, int], n_predict: Union[c_int, int], - n_threads: Union[c_int, int], ): return _lib.llama_beam_search( - ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads + ctx, callback, callback_data, n_beams, n_past, n_predict ) +_lib.llama_beam_search.argtypes = [ + llama_context_p, + llama_beam_search_callback_fn_t, + c_void_p, + c_size_t, + c_int, + c_int, +] +_lib.llama_beam_search.restype = None + # Performance information @@ -1508,9 +1845,10 @@ _lib.llama_print_system_info.argtypes = [] _lib.llama_print_system_info.restype = c_char_p +# NOTE: THIS IS CURRENTLY BROKEN AS ggml_log_callback IS NOT EXPOSED IN LLAMA.H # // Set callback for all future logging events. # // If this is not called, or NULL is supplied, everything is output on stderr. -# LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); +# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); def llama_log_set( log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore ): @@ -1528,4 +1866,3 @@ def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p): _lib.llama_dump_timing_info_yaml.argtypes = [ctypes.c_void_p, llama_context_p] _lib.llama_dump_timing_info_yaml.restype = None - diff --git a/tests/test_llama.py b/tests/test_llama.py index 3b432b5..bb2b42c 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -119,7 +119,7 @@ def test_llama_pickle(): def test_utf8(monkeypatch): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) - n_vocab = llama_cpp.llama_n_vocab(llama.ctx) + n_vocab = llama.n_vocab() ## Set up mock function def mock_eval(*args, **kwargs): diff --git a/vendor/llama.cpp b/vendor/llama.cpp index a98b163..16bc66d 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit a98b1633d5a94d0aa84c7c16e1f8df5ac21fc850 +Subproject commit 16bc66d9479edd5ee12ec734973554d4493c5dfa
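
For downstream users, here is a minimal sketch (not part of the diff itself) of how the reworked low-level API above fits together: the llama_model_params / llama_context_params split, model-level tokenization, and llama_batch / llama_decode in place of the deprecated llama_eval. The model path, thread counts, and the greedy argmax loop are placeholders, and the logits-indexing convention is an assumption flagged in the comments.

```python
# Sketch only: exercises the new split-params / batch-decoding entry points
# exposed by llama_cpp.py in this diff. Values and paths are placeholders.
import ctypes
import llama_cpp

llama_cpp.llama_backend_init(False)  # numa=False

# Model-level parameters (n_gpu_layers, mmap/mlock, ...) now live in
# llama_model_params instead of llama_context_params.
mparams = llama_cpp.llama_model_default_params()
mparams.n_gpu_layers = 0  # placeholder: CPU only
model = llama_cpp.llama_load_model_from_file(b"./model.gguf", mparams)  # placeholder path

# Per-context parameters, including the new n_threads / n_threads_batch.
cparams = llama_cpp.llama_context_default_params()
cparams.seed = 1234
cparams.n_ctx = 512
cparams.n_threads = 4
cparams.n_threads_batch = 4
ctx = llama_cpp.llama_new_context_with_model(model, cparams)

# Tokenization is now model-level: llama_tokenize takes a llama_model_p.
prompt = b"Hello"
tokens = (llama_cpp.llama_token * cparams.n_ctx)()
n_tokens = llama_cpp.llama_tokenize(model, prompt, len(prompt), tokens, len(tokens), True)

# Decode the prompt via the transition helper; positions start at 0, sequence id 0.
batch = llama_cpp.llama_batch_get_one(tokens, n_tokens, 0, 0)
assert llama_cpp.llama_decode(ctx, batch) == 0

n_vocab = llama_cpp.llama_n_vocab(model)  # also model-level now
n_past = n_tokens
buf = (ctypes.c_char * 32)()
for _ in range(16):
    # Assumption: with llama_batch_get_one and logits_all off, llama_get_logits()
    # points at the logits of the last decoded token (mirrors the llama_eval path above).
    logits = llama_cpp.llama_get_logits(ctx)
    next_token = max(range(n_vocab), key=lambda i: logits[i])  # greedy stand-in for the samplers
    if next_token == llama_cpp.llama_token_eos(ctx):
        break
    n = llama_cpp.llama_token_to_piece(model, llama_cpp.llama_token(next_token), buf, len(buf))
    print(bytes(buf[:n]).decode("utf-8", errors="ignore"), end="", flush=True)
    single = (llama_cpp.llama_token * 1)(next_token)
    llama_cpp.llama_decode(ctx, llama_cpp.llama_batch_get_one(single, 1, n_past, 0))
    n_past += 1

llama_cpp.llama_free(ctx)
llama_cpp.llama_free_model(model)
```

Note that the high-level Llama class in this diff still drives generation through the deprecated llama_eval wrapper (now without n_threads); the sketch above only illustrates the new low-level entry points added to llama_cpp.py.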