From 86f8e5ad9162a57a72d3af598e477d0971e89eb7 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Mon, 24 Apr 2023 15:47:54 -0400
Subject: [PATCH] Refactor internal state for Llama class

---
 llama_cpp/llama.py | 63 +++++++++++++++++-----------------------------
 1 file changed, 23 insertions(+), 40 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 70dcea9..f7a6e9e 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -84,16 +84,9 @@ class Llama:
         self.params.embedding = embedding
 
         self.last_n_tokens_size = last_n_tokens_size
-        self.last_n_tokens_data = deque(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
-            maxlen=self.last_n_tokens_size,
-        )
-        self.tokens_consumed = 0
-        self.tokens: List[llama_cpp.llama_token] = []
         self.n_batch = min(n_ctx, n_batch)
-        self.n_tokens = 0
-        self.n_past = 0
-        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
+        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
 
         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ### saving and restoring state, this allows us to continue a completion if the last
@@ -181,14 +174,8 @@ class Llama:
 
     def reset(self):
         """Reset the model state."""
-        self.last_n_tokens_data.extend(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
-        )
-        self.tokens_consumed = 0
-        self.tokens.clear()
-        self.n_tokens = 0
-        self.n_past = 0
-        self.all_logits.clear()
+        self.eval_tokens.clear()
+        self.eval_logits.clear()
 
     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -200,32 +187,25 @@ class Llama:
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
-            self.n_tokens = len(batch)
+            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(self.n_tokens),
-                n_past=llama_cpp.c_int(self.n_past),
+                n_tokens=llama_cpp.c_int(n_tokens),
+                n_past=llama_cpp.c_int(n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
-            self.tokens.extend(batch)
-            self.last_n_tokens_data.extend(batch)
-            self.tokens_consumed += len(batch)
+            self.eval_tokens.extend(batch)
             if self.params.logits_all:
-                self.all_logits.extend(self._logits())
-
-    def _logits(self) -> List[List[float]]:
-        """Return the logits from the last call to llama_eval."""
-        assert self.ctx is not None
-        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-        cols = int(n_vocab)
-        rows = self.n_tokens if self.params.logits_all else 1
-        logits_view = llama_cpp.llama_get_logits(self.ctx)
-        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
-        return logits
+                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+                cols = int(n_vocab)
+                rows = n_tokens
+                logits_view = llama_cpp.llama_get_logits(self.ctx)
+                logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+                self.eval_logits.extend(logits)
 
     def sample(
         self,
@@ -246,10 +226,13 @@ class Llama:
             The sampled token.
         """
         assert self.ctx is not None
+        last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
+            0, self.last_n_tokens_size - len(self.eval_tokens)
+        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
         return llama_cpp.llama_sample_top_p_top_k(
             ctx=self.ctx,
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
-                *self.last_n_tokens_data
+                *last_n_tokens_data
             ),
             last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
             top_k=llama_cpp.c_int(top_k),
@@ -293,13 +276,13 @@ class Llama:
         if (
             reset
             and self._cache
-            and len(self.tokens) > 0
-            and self.tokens == tokens[: len(self.tokens)]
+            and len(self.eval_tokens) > 0
+            and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
             reset = False
-            tokens = tokens[len(self.tokens) :]
+            tokens = tokens[len(self.eval_tokens) :]
         ###
         if reset:
             self.reset()
@@ -537,7 +520,7 @@ class Llama:
         ]
         all_logprobs = [
             [Llama.logit_to_logprob(logit) for logit in row]
-            for row in self.all_logits
+            for row in self.eval_logits
         ]
         for token, token_str, logprobs_token in zip(
             all_tokens, all_token_strs, all_logprobs