Add experimental cache

Andrei Betlen 2023-04-15 12:03:09 -04:00
parent a6372a7ae5
commit 92c077136d
2 changed files with 69 additions and 5 deletions


@@ -11,6 +11,15 @@ from . import llama_cpp
 from .llama_types import *
+class LlamaCache:
+    """Cache for a llama.cpp model.
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+    pass
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -82,6 +91,14 @@ class Llama:
         self.n_past = 0
         self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         if not os.path.exists(model_path):
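The new attributes remember the bytes of the last prompt plus everything generated for it; a later completion whose prompt starts with those bytes can keep the model state instead of resetting. As the comment notes, stop sequences the model already consumed are not accounted for. A standalone sketch of the prefix test (the helper name is illustrative, not part of the library):

    from typing import List

    def can_continue(completion_bytes: List[bytes], new_prompt: str) -> bool:
        previous = b"".join(completion_bytes)  # prior prompt + generated text
        return len(previous) > 0 and new_prompt.encode("utf-8").startswith(previous)

    history = [b"Hello", b", world"]                        # prompt bytes, then generated bytes
    assert can_continue(history, "Hello, world! How are")   # extends the prefix -> reuse state
    assert not can_continue(history, "Goodbye")              # unrelated prompt -> full reset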
@@ -135,6 +152,14 @@ class Llama:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
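Because the cache slot is Optional, passing None detaches it again and completions go back to always resetting the model state. A short sketch, placeholder model path:

    import llama_cpp

    llm = llama_cpp.Llama(model_path="./models/ggml-model.bin")  # placeholder path
    llm.set_cache(llama_cpp.LlamaCache())  # enable experimental prefix reuse
    llm.set_cache(None)                    # detach it; behaviour is as before the cache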
@@ -245,6 +270,17 @@ class Llama:
             The generated tokens.
         """
         assert self.ctx is not None
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
         if reset:
             self.reset()
         while True:
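For clarity, a standalone sketch of the decision generate() now makes: a requested reset is skipped only when a cache is attached and every previously evaluated token matches the start of the new token sequence (the helper name is illustrative, not part of the library):

    from typing import List

    def should_reset(prev_tokens: List[int], new_tokens: List[int], cache_enabled: bool) -> bool:
        prefix_matches = len(prev_tokens) > 0 and prev_tokens == new_tokens[: len(prev_tokens)]
        return not (cache_enabled and prefix_matches)

    assert should_reset([1, 2, 3], [1, 2, 3, 4, 5], cache_enabled=True) is False  # "generate cache hit"
    assert should_reset([1, 2, 3], [1, 9, 3, 4, 5], cache_enabled=True) is True   # diverged, start over
    assert should_reset([1, 2, 3], [1, 2, 3], cache_enabled=False) is True        # no cache attached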
@@ -361,6 +397,21 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
         finish_reason = "length"
         for token in self.generate(
             prompt_tokens,
@@ -368,6 +419,7 @@ class Llama:
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
+            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
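The byte-level counterpart in _create_completion trims the prompt down to the part the model has not seen, tokenizes only that, and passes reset=False through to generate(). A standalone sketch under those assumptions (fake_tokenize stands in for Llama.tokenize and is not part of the library):

    cached = b"".join([b"Q: 1+1? A:", b" 2"])    # previous prompt + previous completion
    new_prompt = "Q: 1+1? A: 2\nQ: 2+2? A:".encode("utf-8")

    def fake_tokenize(data: bytes) -> list:
        return list(data)                         # placeholder: one "token" per byte

    if new_prompt.startswith(cached):
        remainder = new_prompt[len(cached):]      # only the unseen tail...
        tokens = fake_tokenize(b" " + remainder)  # ...is tokenized (note the leading space)
    else:
        tokens = fake_tokenize(b" " + new_prompt)  # cache miss: process the whole prompt

    print(len(tokens), "tokens to evaluate")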
@@ -397,6 +449,9 @@ class Llama:
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -418,6 +473,9 @@ class Llama:
                 break
         if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -434,13 +492,16 @@ class Llama:
             }
             return
-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")
         if echo:
-            text = prompt + text
+            text_str = prompt + text_str
         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
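Each streamed chunk (and, in the non-streaming path, the whole generated text) is appended to _completion_bytes, so joining the list reproduces exactly the prefix a follow-up prompt must start with to hit the cache. A small standalone illustration:

    completion_bytes = [b"Once upon a time"]       # reset to the prompt on a cache miss
    for chunk in (b" there", b" was", b" a", b" llama."):
        completion_bytes.append(chunk)             # mirrors the appends in the hunks above

    follow_up = "Once upon a time there was a llama. It lived"
    assert follow_up.encode("utf-8").startswith(b"".join(completion_bytes))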
@@ -493,7 +554,7 @@ class Llama:
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text,
+                    "text": text_str,
                     "index": 0,
                     "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,


@@ -35,6 +35,7 @@ class Settings(BaseSettings):
     embedding: bool = True
     last_n_tokens_size: int = 64
     logits_all: bool = False
+    cache: bool = False  # WARNING: This is an experimental feature
 app = FastAPI(
@@ -60,6 +61,9 @@ llama = llama_cpp.Llama(
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+if settings.cache:
+    cache = llama_cpp.LlamaCache()
+    llama.set_cache(cache)
 llama_lock = Lock()
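On the server side the cache stays off unless settings.cache is true. Assuming the default pydantic BaseSettings behaviour of reading environment variables named after each field (no custom prefix is visible in this diff), it could be enabled like this; the model path is a placeholder:

    import os

    os.environ["MODEL"] = "./models/ggml-model.bin"  # placeholder model path
    os.environ["CACHE"] = "true"                     # Settings.cache -> True on startup

    # Starting the server with these variables set makes it construct a LlamaCache
    # and attach it via llama.set_cache(cache), per the `if settings.cache:` block above.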
@@ -68,7 +72,6 @@ def get_llama():
         yield llama
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)