diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 012bb86..6babebd 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -20,6 +20,9 @@ from collections import deque, OrderedDict from . import llama_cpp from .llama_types import * +import numpy as np +import numpy.typing as npt + class LlamaCache: """Cache for a llama.cpp model.""" @@ -73,11 +76,15 @@ class LlamaState: self, eval_tokens: Deque[int], eval_logits: Deque[List[float]], + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] llama_state_size: int, ): self.eval_tokens = eval_tokens self.eval_logits = eval_logits + self.input_ids = input_ids + self.scores = scores self.llama_state = llama_state self.llama_state_size = llama_state_size @@ -207,20 +214,14 @@ class Llama: self._n_vocab = self.n_vocab() self._n_ctx = self.n_ctx() - data = (llama_cpp.llama_token_data * self._n_vocab)( - *[ - llama_cpp.llama_token_data( - id=llama_cpp.llama_token(i), - logit=llama_cpp.c_float(0.0), - p=llama_cpp.c_float(0.0), - ) - for i in range(self._n_vocab) - ] - ) size = llama_cpp.c_size_t(self._n_vocab) - sorted = False + sorted = llama_cpp.c_bool(False) + self._candidates_data = np.array( + [], dtype=[("id", np.intc), ("logit", np.single), ("p", np.single)] + ) + self._candidates_data.resize(3, self._n_vocab) candidates = llama_cpp.llama_token_data_array( - data=data, + data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), size=size, sorted=sorted, ) @@ -228,6 +229,9 @@ class Llama: self._token_nl = Llama.token_nl() self._token_eos = Llama.token_eos() + self._input_ids = np.array([], dtype=np.intc) + self._scores = np.ndarray((0, self._n_vocab), dtype=np.single) + def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: """Tokenize a string. @@ -319,6 +323,9 @@ class Llama: raise RuntimeError(f"llama_eval returned {return_code}") # Save tokens self.eval_tokens.extend(batch) + self._input_ids: npt.NDArray[np.intc] = np.concatenate( + (self._input_ids, np.array(batch, dtype=np.intc)), axis=0 + ) # Save logits rows = n_tokens if self.params.logits_all else 1 n_vocab = self._n_vocab @@ -326,6 +333,9 @@ class Llama: logits_view = llama_cpp.llama_get_logits(self.ctx) logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] self.eval_logits.extend(logits) + self._scores: npt.NDArray[np.single] = np.concatenate( + (self._scores, np.array(logits, dtype=np.single)), axis=0 + ) def _sample( self, @@ -354,18 +364,23 @@ class Llama: if last_n_tokens_size.value < 0 else last_n_tokens_size ) - logits = self.eval_logits[-1] + logits: npt.NDArray[np.single] = self._scores[-1, :] if logits_processor is not None: - logits = logits_processor(list(self.eval_tokens), logits) - self.eval_logits[-1] = logits + logits = np.array( + logits_processor(list(self.eval_tokens), logits.tolist()), + dtype=np.single, + ) + self._scores[-1, :] = logits + self.eval_logits[-1] = logits.tolist() nl_logit = logits[self._token_nl] candidates = self._candidates - for i, logit in enumerate(logits): - candidates.data[i].id = llama_cpp.llama_token(i) - candidates.data[i].logit = llama_cpp.c_float(logit) - candidates.data[i].p = llama_cpp.c_float(0.0) + candidates_data = self._candidates_data + candidates_data["id"] = np.arange(n_vocab, dtype=np.intc) # type: ignore + candidates_data["logit"] = logits + candidates_data["p"] = np.zeros(n_vocab, dtype=np.single) + candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p) candidates.sorted = llama_cpp.c_bool(False) candidates.size = llama_cpp.c_size_t(n_vocab) llama_cpp.llama_sample_repetition_penalty( @@ -1371,6 +1386,8 @@ class Llama: return LlamaState( eval_tokens=self.eval_tokens.copy(), eval_logits=self.eval_logits.copy(), + scores=self._scores.copy(), + input_ids=self._input_ids.copy(), llama_state=llama_state_compact, llama_state_size=n_bytes, ) @@ -1379,6 +1396,8 @@ class Llama: assert self.ctx is not None self.eval_tokens = state.eval_tokens.copy() self.eval_logits = state.eval_logits.copy() + self._scores = state.scores.copy() + self._input_ids = state.input_ids.copy() state_size = state.llama_state_size if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size: raise RuntimeError("Failed to set llama state data") diff --git a/setup.py b/setup.py index bd7192f..198dd74 100644 --- a/setup.py +++ b/setup.py @@ -16,9 +16,7 @@ setup( license="MIT", package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"}, packages=["llama_cpp", "llama_cpp.server"], - install_requires=[ - "typing-extensions>=4.5.0", - ], + install_requires=["typing-extensions>=4.5.0", "numpy>=1.24.2"], extras_require={ "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"], },