From 92c077136d1f0b029f8907a79eae009a750005e2 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Sat, 15 Apr 2023 12:03:09 -0400
Subject: [PATCH] Add experimental cache

---
 llama_cpp/llama.py           | 69 +++++++++++++++++++++++++++++++++---
 llama_cpp/server/__main__.py |  5 ++-
 2 files changed, 69 insertions(+), 5 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 121f91d..b92801c 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -11,6 +11,15 @@ from . import llama_cpp
 from .llama_types import *
 
 
+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
@@ -82,6 +91,14 @@ class Llama:
         self.n_past = 0
         self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
 
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###
+
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
         if not os.path.exists(model_path):
@@ -135,6 +152,14 @@ class Llama:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
@@ -245,6 +270,17 @@ class Llama:
             The generated tokens.
         """
         assert self.ctx is not None
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
         if reset:
             self.reset()
         while True:
@@ -361,6 +397,21 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )
 
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
+
         finish_reason = "length"
         for token in self.generate(
             prompt_tokens,
@@ -368,6 +419,7 @@ class Llama:
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
+            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
@@ -397,6 +449,9 @@ class Llama:
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -418,6 +473,9 @@ class Llama:
                 break
 
         if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -434,13 +492,16 @@ class Llama:
             }
             return
 
-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")
 
         if echo:
-            text = prompt + text
+            text_str = prompt + text_str
 
         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
@@ -493,7 +554,7 @@ class Llama:
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text,
+                    "text": text_str,
                     "index": 0,
                     "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,
diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py
index 7fc3c57..48481c6 100644
--- a/llama_cpp/server/__main__.py
+++ b/llama_cpp/server/__main__.py
@@ -35,6 +35,7 @@ class Settings(BaseSettings):
     embedding: bool = True
     last_n_tokens_size: int = 64
     logits_all: bool = False
+    cache: bool = False  # WARNING: This is an experimental feature
 
 
 app = FastAPI(
@@ -60,6 +61,9 @@ llama = llama_cpp.Llama(
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+if settings.cache:
+    cache = llama_cpp.LlamaCache()
+    llama.set_cache(cache)
 llama_lock = Lock()
 
 
@@ -68,7 +72,6 @@ def get_llama():
         yield llama
 
 
-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)