diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 0a69b2c..487f44d 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -4,7 +4,7 @@ import uuid import time import math import multiprocessing -from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque +from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple from collections import deque from . import llama_cpp @@ -15,15 +15,34 @@ class LlamaCache: """Cache for a llama.cpp model.""" def __init__(self): - self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict() + self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict() + + def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]: + return [ + key + for _, key in sorted( + ((len(key), key) for key in self.cache_state.keys()), reverse=True + ) + ] + + def _find_key( + self, key: Tuple[llama_cpp.llama_token, ...] + ) -> Optional[Tuple[llama_cpp.llama_token, ...]]: + for k in self._sorted_keys(): + if key[: len(k)] == k: + return k + return None def __getitem__( self, key: Sequence[llama_cpp.llama_token] ) -> Optional["LlamaState"]: - return self.cache_state.get(tuple(key), None) + _key = self._find_key(tuple(key)) + if _key is None: + return None + return self.cache_state[_key] def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool: - return tuple(key) in self.cache_state + return self._find_key(tuple(key)) is not None def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"): self.cache_state = dict() # NOTE: Currently limit to one cache entry. @@ -295,7 +314,7 @@ class Llama: if ( reset and len(self.eval_tokens) > 0 - and self.eval_tokens == tokens[: len(self.eval_tokens)] + and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)]) ): if self.verbose: print("generate cache hit", file=sys.stderr) @@ -438,6 +457,8 @@ class Llama: if self.cache and len(completion_tokens) == 0: if prompt_tokens not in self.cache: + if self.verbose: + print("cache miss", file=sys.stderr) self.cache[prompt_tokens] = self.save_state() completion_tokens.append(token)