diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index f6017a1..6ac7f1d 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -4,6 +4,7 @@ import uuid import time import math import multiprocessing +from abc import ABC from typing import ( List, Optional, @@ -26,21 +27,94 @@ import numpy as np import numpy.typing as npt +class LlamaCache(ABC): + """Base cache class for a llama.cpp model.""" -class LlamaCache: - """Cache for a llama.cpp model.""" + def __init__(self, capacity_bytes: int = (2 << 30)): + pass + + @property + def cache_size(self): + return 0 + + def _find_longest_prefix_key( + self, + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: + pass + + def __getitem__(self, key: Sequence[int]) -> "LlamaState": + pass + + def __contains__(self, key: Sequence[int]) -> bool: + pass + + def __setitem__(self, key: Sequence[int], value: "LlamaState"): + pass + + +class LlamaRAMCache(LlamaCache): + """Cache for a llama.cpp model using RAM.""" + + def __init__(self, capacity_bytes: int = (2 << 30)): + super().__init__(capacity_bytes) + self.capacity_bytes = capacity_bytes + self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict() + + @property + def cache_size(self): + return sum([state.llama_state_size for state in self.cache_state.values()]) + + def _find_longest_prefix_key( + self, + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: + min_len = 0 + min_key = None + keys = ( + (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys() + ) + for k, prefix_len in keys: + if prefix_len > min_len: + min_len = prefix_len + min_key = k + return min_key + + def __getitem__(self, key: Sequence[int]) -> "LlamaState": + key = tuple(key) + _key = self._find_longest_prefix_key(key) + if _key is None: + raise KeyError("Key not found") + value = self.cache_state[_key] + self.cache_state.move_to_end(_key) + return value + + def __contains__(self, key: Sequence[int]) -> bool: + return self._find_longest_prefix_key(tuple(key)) is not None + + def __setitem__(self, key: Sequence[int], value: "LlamaState"): + key = tuple(key) + if key in self.cache_state: + del self.cache_state[key] + self.cache_state[key] = value + while self.cache_size > self.capacity_bytes: + self.cache_state.popitem(last=False) + + +class LlamaDiskCache(LlamaCache): + """Cache for a llama.cpp model using disk.""" def __init__(self, cache_dir="./llama_cache", capacity_bytes: int = (2 << 30)): + super().__init__(capacity_bytes) self.cache = diskcache.Cache(cache_dir) - self.capacity_bytes = capacity_bytes @property def cache_size(self): return self.cache.volume() def _find_longest_prefix_key( - self, - key: Tuple[int, ...], + self, + key: Tuple[int, ...], ) -> Optional[Tuple[int, ...]]: min_len = 0 min_key = None @@ -60,9 +134,6 @@ class LlamaCache: self.cache.push(_key) return value - def __contains__(self, key: Sequence[int]) -> bool: - return self._find_longest_prefix_key(tuple(key)) is not None - def __setitem__(self, key: Sequence[int], value: "LlamaState"): key = tuple(key) if key in self.cache: @@ -73,16 +144,15 @@ class LlamaCache: del self.cache[key_to_remove] - class LlamaState: def __init__( - self, - eval_tokens: Deque[int], - eval_logits: Deque[List[float]], - input_ids: npt.NDArray[np.intc], - scores: npt.NDArray[np.single], - llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] - llama_state_size: int, + self, + eval_tokens: Deque[int], + eval_logits: Deque[List[float]], + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], + llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] + llama_state_size: int, ): self.eval_tokens = eval_tokens self.eval_logits = eval_logits @@ -114,25 +184,25 @@ class Llama: """High-level Python wrapper for a llama.cpp model.""" def __init__( - self, - model_path: str, - # NOTE: These parameters are likely to change in the future. - n_ctx: int = 512, - n_parts: int = -1, - n_gpu_layers: int = 0, - seed: int = 1337, - f16_kv: bool = True, - logits_all: bool = False, - vocab_only: bool = False, - use_mmap: bool = True, - use_mlock: bool = False, - embedding: bool = False, - n_threads: Optional[int] = None, - n_batch: int = 512, - last_n_tokens_size: int = 64, - lora_base: Optional[str] = None, - lora_path: Optional[str] = None, - verbose: bool = True, + self, + model_path: str, + # NOTE: These parameters are likely to change in the future. + n_ctx: int = 512, + n_parts: int = -1, + n_gpu_layers: int = 0, + seed: int = 1337, + f16_kv: bool = True, + logits_all: bool = False, + vocab_only: bool = False, + use_mmap: bool = True, + use_mlock: bool = False, + embedding: bool = False, + n_threads: Optional[int] = None, + n_batch: int = 512, + last_n_tokens_size: int = 64, + lora_base: Optional[str] = None, + lora_path: Optional[str] = None, + verbose: bool = True, ): """Load a llama.cpp model from `model_path`. @@ -201,12 +271,12 @@ class Llama: if self.lora_path: if llama_cpp.llama_apply_lora_from_file( - self.ctx, - llama_cpp.c_char_p(self.lora_path.encode("utf-8")), - llama_cpp.c_char_p(self.lora_base.encode("utf-8")) - if self.lora_base is not None - else llama_cpp.c_char_p(0), - llama_cpp.c_int(self.n_threads), + self.ctx, + llama_cpp.c_char_p(self.lora_path.encode("utf-8")), + llama_cpp.c_char_p(self.lora_base.encode("utf-8")) + if self.lora_base is not None + else llama_cpp.c_char_p(0), + llama_cpp.c_int(self.n_threads), ): raise RuntimeError( f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}" @@ -317,7 +387,7 @@ class Llama: assert self.ctx is not None n_ctx = self._n_ctx for i in range(0, len(tokens), self.n_batch): - batch = tokens[i : min(len(tokens), i + self.n_batch)] + batch = tokens[i: min(len(tokens), i + self.n_batch)] n_past = min(n_ctx - len(batch), len(self._input_ids)) n_tokens = len(batch) return_code = llama_cpp.llama_eval( @@ -339,28 +409,28 @@ class Llama: n_vocab = self._n_vocab cols = n_vocab logits_view = llama_cpp.llama_get_logits(self.ctx) - logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] + logits = [logits_view[i * cols: (i + 1) * cols] for i in range(rows)] self.eval_logits.extend(logits) self._scores: npt.NDArray[np.single] = np.concatenate( (self._scores, np.array(logits, dtype=np.single)), axis=0 ) def _sample( - self, - last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token] - last_n_tokens_size: llama_cpp.c_int, - top_k: llama_cpp.c_int, - top_p: llama_cpp.c_float, - temp: llama_cpp.c_float, - tfs_z: llama_cpp.c_float, - repeat_penalty: llama_cpp.c_float, - frequency_penalty: llama_cpp.c_float, - presence_penalty: llama_cpp.c_float, - mirostat_mode: llama_cpp.c_int, - mirostat_tau: llama_cpp.c_float, - mirostat_eta: llama_cpp.c_float, - penalize_nl: bool = True, - logits_processor: Optional[LogitsProcessorList] = None, + self, + last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token] + last_n_tokens_size: llama_cpp.c_int, + top_k: llama_cpp.c_int, + top_p: llama_cpp.c_float, + temp: llama_cpp.c_float, + tfs_z: llama_cpp.c_float, + repeat_penalty: llama_cpp.c_float, + frequency_penalty: llama_cpp.c_float, + presence_penalty: llama_cpp.c_float, + mirostat_mode: llama_cpp.c_int, + mirostat_tau: llama_cpp.c_float, + mirostat_eta: llama_cpp.c_float, + penalize_nl: bool = True, + logits_processor: Optional[LogitsProcessorList] = None, ): assert self.ctx is not None assert len(self.eval_logits) > 0 @@ -480,19 +550,19 @@ class Llama: ) def sample( - self, - top_k: int = 40, - top_p: float = 0.95, - temp: float = 0.80, - repeat_penalty: float = 1.1, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_eta: float = 0.1, - mirostat_tau: float = 5.0, - penalize_nl: bool = True, - logits_processor: Optional[LogitsProcessorList] = None, + self, + top_k: int = 40, + top_p: float = 0.95, + temp: float = 0.80, + repeat_penalty: float = 1.1, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_eta: float = 0.1, + mirostat_tau: float = 5.0, + penalize_nl: bool = True, + logits_processor: Optional[LogitsProcessorList] = None, ): """Sample a token from the model. @@ -508,7 +578,7 @@ class Llama: assert self.ctx is not None last_n_tokens_data = [llama_cpp.llama_token(0)] * max( 0, self.last_n_tokens_size - len(self._input_ids) - ) + self._input_ids[-self.last_n_tokens_size :].tolist() + ) + self._input_ids[-self.last_n_tokens_size:].tolist() return self._sample( last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( *last_n_tokens_data @@ -529,21 +599,21 @@ class Llama: ) def generate( - self, - tokens: Sequence[int], - top_k: int = 40, - top_p: float = 0.95, - temp: float = 0.80, - repeat_penalty: float = 1.1, - reset: bool = True, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, + self, + tokens: Sequence[int], + top_k: int = 40, + top_p: float = 0.95, + temp: float = 0.80, + repeat_penalty: float = 1.1, + reset: bool = True, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -606,7 +676,7 @@ class Llama: logits_processor=logits_processor, ) if stopping_criteria is not None and stopping_criteria( - self._input_ids.tolist(), self._scores[-1, :].tolist() + self._input_ids.tolist(), self._scores[-1, :].tolist() ): return tokens_or_none = yield token @@ -615,7 +685,7 @@ class Llama: tokens.extend(tokens_or_none) def create_embedding( - self, input: Union[str, List[str]], model: Optional[str] = None + self, input: Union[str, List[str]], model: Optional[str] = None ) -> Embedding: """Embed a string. @@ -650,8 +720,8 @@ class Llama: n_tokens = len(tokens) total_tokens += n_tokens embedding = llama_cpp.llama_get_embeddings(self.ctx)[ - : llama_cpp.llama_n_embd(self.ctx) - ] + : llama_cpp.llama_n_embd(self.ctx) + ] data.append( { @@ -685,27 +755,27 @@ class Llama: return list(map(float, self.create_embedding(input)["data"][0]["embedding"])) def _create_completion( - self, - prompt: str, - suffix: Optional[str] = None, - max_tokens: int = 16, - temperature: float = 0.8, - top_p: float = 0.95, - logprobs: Optional[int] = None, - echo: bool = False, - stop: Optional[Union[str, List[str]]] = [], - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repeat_penalty: float = 1.1, - top_k: int = 40, - stream: bool = False, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_processor: Optional[LogitsProcessorList] = None, + self, + prompt: str, + suffix: Optional[str] = None, + max_tokens: int = 16, + temperature: float = 0.8, + top_p: float = 0.95, + logprobs: Optional[int] = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = [], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repeat_penalty: float = 1.1, + top_k: int = 40, + stream: bool = False, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_processor: Optional[LogitsProcessorList] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None @@ -757,19 +827,19 @@ class Llama: finish_reason = "length" multibyte_fix = 0 for token in self.generate( - prompt_tokens, - top_k=top_k, - top_p=top_p, - temp=temperature, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - repeat_penalty=repeat_penalty, - stopping_criteria=stopping_criteria, - logits_processor=logits_processor, + prompt_tokens, + top_k=top_k, + top_p=top_p, + temp=temperature, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + repeat_penalty=repeat_penalty, + stopping_criteria=stopping_criteria, + logits_processor=logits_processor, ): if token == self._token_eos: text = self.detokenize(completion_tokens) @@ -821,7 +891,7 @@ class Llama: token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token if token_end_position >= ( - remaining_length - first_stop_position - 1 + remaining_length - first_stop_position - 1 ): break logprobs_or_none: Optional[CompletionLogprobs] = None @@ -882,7 +952,7 @@ class Llama: break if stopping_criteria is not None and stopping_criteria( - self._input_ids.tolist(), self._scores[-1, :].tolist() + self._input_ids.tolist(), self._scores[-1, :].tolist() ): text = self.detokenize(completion_tokens) finish_reason = "stop" @@ -947,8 +1017,8 @@ class Llama: "choices": [ { "text": last_text[ - : len(last_text) - (token_end_position - end) - ].decode("utf-8", errors="ignore"), + : len(last_text) - (token_end_position - end) + ].decode("utf-8", errors="ignore"), "index": 0, "logprobs": logprobs_or_none, "finish_reason": finish_reason, @@ -1014,10 +1084,10 @@ class Llama: for token in all_tokens ] all_logprobs = [ - Llama.logits_to_logprobs(row.tolist()) for row in self._scores - ][token_offset:] + Llama.logits_to_logprobs(row.tolist()) for row in self._scores + ][token_offset:] for token, token_str, logprobs_token in zip( - all_tokens, all_token_strs, all_logprobs + all_tokens, all_token_strs, all_logprobs ): text_offsets.append(text_offset) text_offset += len(token_str) @@ -1068,27 +1138,27 @@ class Llama: } def create_completion( - self, - prompt: str, - suffix: Optional[str] = None, - max_tokens: int = 128, - temperature: float = 0.8, - top_p: float = 0.95, - logprobs: Optional[int] = None, - echo: bool = False, - stop: Optional[Union[str, List[str]]] = [], - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repeat_penalty: float = 1.1, - top_k: int = 40, - stream: bool = False, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_processor: Optional[LogitsProcessorList] = None, + self, + prompt: str, + suffix: Optional[str] = None, + max_tokens: int = 128, + temperature: float = 0.8, + top_p: float = 0.95, + logprobs: Optional[int] = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = [], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repeat_penalty: float = 1.1, + top_k: int = 40, + stream: bool = False, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_processor: Optional[LogitsProcessorList] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1141,27 +1211,27 @@ class Llama: return completion def __call__( - self, - prompt: str, - suffix: Optional[str] = None, - max_tokens: int = 128, - temperature: float = 0.8, - top_p: float = 0.95, - logprobs: Optional[int] = None, - echo: bool = False, - stop: Optional[Union[str, List[str]]] = [], - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repeat_penalty: float = 1.1, - top_k: int = 40, - stream: bool = False, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_processor: Optional[LogitsProcessorList] = None, + self, + prompt: str, + suffix: Optional[str] = None, + max_tokens: int = 128, + temperature: float = 0.8, + top_p: float = 0.95, + logprobs: Optional[int] = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = [], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repeat_penalty: float = 1.1, + top_k: int = 40, + stream: bool = False, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_processor: Optional[LogitsProcessorList] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1209,7 +1279,7 @@ class Llama: ) def _convert_text_completion_to_chat( - self, completion: Completion + self, completion: Completion ) -> ChatCompletion: return { "id": "chat" + completion["id"], @@ -1230,8 +1300,8 @@ class Llama: } def _convert_text_completion_chunks_to_chat( - self, - chunks: Iterator[CompletionChunk], + self, + chunks: Iterator[CompletionChunk], ) -> Iterator[ChatCompletionChunk]: for i, chunk in enumerate(chunks): if i == 0: @@ -1267,22 +1337,22 @@ class Llama: } def create_chat_completion( - self, - messages: List[ChatCompletionMessage], - temperature: float = 0.2, - top_p: float = 0.95, - top_k: int = 40, - stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - max_tokens: int = 256, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repeat_penalty: float = 1.1, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, + self, + messages: List[ChatCompletionMessage], + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: """Generate a chat completion from a list of messages.