From 4f2b5d0b5321bedc879ee9b9a19ca15d18ddb995 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Sat, 8 Jul 2023 00:05:10 -0400
Subject: [PATCH] Format

---
 llama_cpp/llama.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 130e013..f8e0527 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -324,7 +324,7 @@ class Llama:
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
-        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore
+        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
         self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)
 
         self.n_tokens = 0
@@ -445,8 +445,12 @@ class Llama:
         # Save logits
         rows = n_tokens if self.params.logits_all else 1
         cols = self._n_vocab
-        offset = 0 if self.params.logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
-        self.scores[self.n_tokens + offset: self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
+        offset = (
+            0 if self.params.logits_all else n_tokens - 1
+        )  # NOTE: Only save the last token logits if logits_all is False
+        self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape(
+            -1
+        )[:] = llama_cpp.llama_get_logits(self.ctx)[: rows * cols]
         # Update n_tokens
         self.n_tokens += n_tokens
 
@@ -491,7 +495,7 @@ class Llama:
         candidates_data = self._candidates_data
         candidates_data["id"][:] = self._candidates_data_id  # type: ignore
         candidates_data["logit"][:] = logits
-        candidates_data["p"][:] = self._candidates_data_p # type: ignore
+        candidates_data["p"][:] = self._candidates_data_p  # type: ignore
         candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
@@ -537,7 +541,7 @@ class Llama:
             mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
             llama_cpp.llama_sample_temperature(
                 ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
             return llama_cpp.llama_sample_token_mirostat_v2(
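
Note on the second hunk: the reflow is purely cosmetic and the saving logic
is unchanged. As a rough illustration (not part of the patch), here is a
minimal numpy-only sketch of what that slice assignment does; scores, n_past,
and fake_logits are hypothetical stand-ins for Llama.scores, Llama.n_tokens,
and the buffer returned by llama_cpp.llama_get_logits:

    import numpy as np

    n_ctx, n_vocab, n_tokens, n_past = 8, 4, 3, 2
    logits_all = False

    scores = np.zeros((n_ctx, n_vocab), dtype=np.single)
    # Pretend the C API returned one row of logits per evaluated token.
    fake_logits = np.arange(n_tokens * n_vocab, dtype=np.single) + 1.0

    rows = n_tokens if logits_all else 1
    # With logits_all=False, skip ahead so only the last token's row is kept.
    offset = 0 if logits_all else n_tokens - 1
    scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[:] = fake_logits[
        : rows * n_vocab
    ]

    # Only the row for the last evaluated token was written.
    assert scores[n_past + n_tokens - 1].any()
    assert not scores[: n_past + n_tokens - 1].any()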