Use pre-allocated buffers to store input_ids and scores

Andrei Betlen 2023-06-29 00:40:47 -04:00
parent a5e059c053
commit b95b0ffbeb


@@ -141,7 +141,7 @@ class LlamaDiskCache(BaseLlamaCache):
if _key is None:
raise KeyError("Key not found")
value: "LlamaState" = self.cache.pop(_key) # type: ignore
# NOTE: This puts an integer as key in cache, which breaks,
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
# self.cache.push(_key, side="front") # type: ignore
return value
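The NOTE above refers to `Llama.longest_token_prefix(k, key)`, which walks two token sequences and counts how many leading tokens match; a plain integer key can't be compared token-by-token, which is why re-pushing `_key` stays commented out. A minimal sketch of such a prefix helper (an assumption about its behavior, not the library's exact code):

```python
from typing import Sequence

def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    # Count how many leading tokens the two sequences share.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

assert longest_token_prefix([1, 2, 3, 9], [1, 2, 3, 4]) == 3
```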
@@ -166,17 +166,15 @@ class LlamaDiskCache(BaseLlamaCache):
class LlamaState:
def __init__(
self,
eval_tokens: Deque[int],
eval_logits: Deque[List[float]],
input_ids: npt.NDArray[np.intc],
scores: npt.NDArray[np.single],
n_tokens: int,
llama_state: bytes,
llama_state_size: int,
):
self.eval_tokens = eval_tokens
self.eval_logits = eval_logits
self.input_ids = input_ids
self.scores = scores
self.n_tokens = n_tokens
self.llama_state = llama_state
self.llama_state_size = llama_state_size
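After this change `LlamaState` snapshots the two pre-allocated numpy buffers plus `n_tokens`, the count of valid rows, instead of carrying deques of tokens and per-token logits. A hedged sketch of the equivalent container (field names follow the diff; the dataclass form is illustrative):

```python
from dataclasses import dataclass
import numpy as np
import numpy.typing as npt

@dataclass
class LlamaStateSketch:
    input_ids: npt.NDArray[np.intc]   # shape (n_ctx,): token ids seen so far
    scores: npt.NDArray[np.single]    # shape (n_ctx, n_vocab): per-token logits
    n_tokens: int                     # number of valid rows in both buffers
    llama_state: bytes                # opaque llama.cpp state blob
    llama_state_size: int             # size of that blob in bytes
```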
@@ -267,8 +265,6 @@ class Llama:
self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch)
self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
self.cache: Optional[BaseLlamaCache] = None
@@ -329,8 +325,30 @@ class Llama:
self._token_nl = Llama.token_nl()
self._token_eos = Llama.token_eos()
self._input_ids = np.array([], dtype=np.intc)
self._scores: npt.NDArray[np.single] = np.ndarray((0, self._n_vocab), dtype=np.single)
self.n_tokens = 0
self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
self.scores: npt.NDArray[np.single] = np.ndarray(
(n_ctx, self._n_vocab), dtype=np.single
)
@property
def _input_ids(self) -> npt.NDArray[np.intc]:
return self.input_ids[: self.n_tokens]
@property
def _scores(self) -> npt.NDArray[np.single]:
return self.scores[: self.n_tokens, :]
@property
def eval_tokens(self) -> Deque[int]:
return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx)
@property
def eval_logits(self) -> Deque[List[float]]:
return deque(
self.scores[: self.n_tokens, :].tolist(),
maxlen=self._n_ctx if self.params.logits_all else 1,
)
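This is the core of the commit: allocate `(n_ctx,)` and `(n_ctx, n_vocab)` arrays once, track how many rows are valid in `n_tokens`, and expose the valid region through zero-copy slice views. The old `eval_tokens`/`eval_logits` deques survive as read-only properties rebuilt from the same arrays, so existing callers keep working. A self-contained sketch of the pattern (names mirror the diff; the class itself is a stand-in):

```python
import numpy as np
import numpy.typing as npt

class TokenBuffers:
    def __init__(self, n_ctx: int, n_vocab: int):
        self.n_tokens = 0
        # Allocated once up front. The diff uses np.ndarray, which leaves
        # the memory uninitialized; that is safe because only rows below
        # n_tokens are ever read. np.zeros here is just tidier for a demo.
        self.input_ids: npt.NDArray[np.intc] = np.zeros((n_ctx,), dtype=np.intc)
        self.scores: npt.NDArray[np.single] = np.zeros((n_ctx, n_vocab), dtype=np.single)

    @property
    def _input_ids(self) -> npt.NDArray[np.intc]:
        # numpy slicing returns a view: no copy, and writes pass through.
        return self.input_ids[: self.n_tokens]

    @property
    def _scores(self) -> npt.NDArray[np.single]:
        return self.scores[: self.n_tokens, :]
```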
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
"""Tokenize a string.
@@ -397,10 +415,7 @@ class Llama:
def reset(self):
"""Reset the model state."""
self.eval_tokens.clear()
self.eval_logits.clear()
self._input_ids = np.array([], dtype=np.intc)
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
self.n_tokens = 0
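With the counter-based design, `reset()` collapses to `self.n_tokens = 0`: the buffers keep their allocation and stale rows are simply never read again. Continuing the sketch above:

```python
buf = TokenBuffers(n_ctx=8, n_vocab=4)
buf.input_ids[:3] = [1, 2, 3]
buf.n_tokens = 3
assert buf._input_ids.tolist() == [1, 2, 3]

buf.n_tokens = 0            # reset(): O(1), no reallocation
assert buf._input_ids.size == 0
```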
def eval(self, tokens: Sequence[int]):
"""Evaluate a list of tokens.
@@ -410,7 +425,6 @@ class Llama:
"""
assert self.ctx is not None
n_ctx = self._n_ctx
scores: List[npt.NDArray[np.single]] = []
for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)]
n_past = min(n_ctx - len(batch), len(self._input_ids))
@@ -425,19 +439,16 @@ class Llama:
if return_code != 0:
raise RuntimeError(f"llama_eval returned {return_code}")
# Save tokens
self.eval_tokens.extend(batch)
self._input_ids: npt.NDArray[np.intc] = np.concatenate(
(self._input_ids, np.array(batch, dtype=np.intc)), axis=0
)
self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
# Save logits
rows = n_tokens if self.params.logits_all else 1
n_vocab = self._n_vocab
cols = n_vocab
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
self.eval_logits.extend(logits)
scores.append(np.array(logits, dtype=np.single))
self._scores = np.concatenate(scores)
self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits
# Update n_tokens
self.n_tokens += n_tokens
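Each batch now writes its tokens and logits into the next free rows by slice assignment, which is O(batch) per step, instead of rebuilding the arrays with `np.concatenate`, which is O(total tokens). A hedged sketch of that write path, where `logits_view` stands in for the flat float buffer returned by `llama_cpp.llama_get_logits()`:

```python
import numpy as np

def store_batch(buf, batch, logits_view, logits_all: bool, n_vocab: int):
    n_tokens = len(batch)
    rows = n_tokens if logits_all else 1
    # Tokens go into the next free slots -- no concatenate().
    buf.input_ids[buf.n_tokens : buf.n_tokens + n_tokens] = batch
    logits = np.asarray(
        logits_view[: rows * n_vocab], dtype=np.single
    ).reshape(rows, n_vocab)
    # When rows == 1, numpy broadcasts the single row across the whole
    # batch slice, so the last row (the only one sampled from) is correct.
    buf.scores[buf.n_tokens : buf.n_tokens + n_tokens, :] = logits
    buf.n_tokens += n_tokens
```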
def _sample(
self,
@@ -457,8 +468,7 @@ class Llama:
logits_processor: Optional[LogitsProcessorList] = None,
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
assert self._scores.shape[0] > 0
assert self.n_tokens > 0
n_vocab = self._n_vocab
n_ctx = self._n_ctx
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -475,7 +485,6 @@ class Llama:
dtype=np.single,
)
self._scores[-1, :] = logits
self.eval_logits[-1] = logits.tolist()
nl_logit = logits[self._token_nl]
candidates = self._candidates
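Because `_scores` is a view into the backing `scores` array, assigning the processed logits to `_scores[-1, :]` mutates the buffer in place, so the separate `eval_logits[-1]` mirror can be dropped. A tiny demonstration of the write-through:

```python
import numpy as np

scores = np.zeros((4, 3), dtype=np.single)
n_tokens = 2
view = scores[:n_tokens, :]            # shares memory with `scores`
view[-1, :] = [0.1, 0.2, 0.7]          # e.g. logits after a LogitsProcessor
assert np.allclose(scores[1], [0.1, 0.2, 0.7])
```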
@@ -672,14 +681,7 @@ class Llama:
print("Llama.generate: prefix-match hit", file=sys.stderr)
reset = False
tokens = tokens[longest_prefix:]
self._input_ids = self._input_ids[:longest_prefix]
self._scores = self._scores[:longest_prefix, :]
for _ in range(len(self.eval_tokens) - longest_prefix):
self.eval_tokens.pop()
try:
self.eval_logits.pop()
except IndexError:
pass
self.n_tokens = longest_prefix
if reset:
self.reset()
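The prefix-match fast path in `generate()` simplifies the same way: rolling back to the shared prefix no longer slices arrays or pops deques, it just lowers `n_tokens`. A hedged sketch (the real code computes the prefix with `Llama.longest_token_prefix`):

```python
def rollback_to_prefix(buf, tokens):
    # Count how many leading tokens are already cached in the buffer.
    longest_prefix = 0
    for a, b in zip(buf.input_ids[: buf.n_tokens], tokens):
        if a != b:
            break
        longest_prefix += 1
    buf.n_tokens = longest_prefix       # keep the cached prefix as-is
    return tokens[longest_prefix:]      # only the tail needs re-eval
```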
@@ -819,7 +821,9 @@ class Llama:
llama_cpp.llama_reset_timings(self.ctx)
if len(prompt_tokens) > self._n_ctx:
raise ValueError(f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}")
raise ValueError(
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
)
# Truncate max_tokens if requested tokens would exceed the context window
max_tokens = (
@@ -1513,22 +1517,20 @@ class Llama:
file=sys.stderr,
)
return LlamaState(
eval_tokens=self.eval_tokens.copy(),
eval_logits=self.eval_logits.copy(),
scores=self._scores.copy(),
input_ids=self._input_ids.copy(),
scores=self.scores.copy(),
input_ids=self.input_ids.copy(),
n_tokens=self.n_tokens,
llama_state=bytes(llama_state_compact),
llama_state_size=n_bytes,
)
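`save_state` now copies the whole pre-allocated buffers plus `n_tokens` instead of copying deques, and `load_state` is the mirror image. A minimal round-trip sketch under the same assumptions as above:

```python
def save(buf) -> dict:
    return {
        "input_ids": buf.input_ids.copy(),  # full (n_ctx,) buffer
        "scores": buf.scores.copy(),        # full (n_ctx, n_vocab) buffer
        "n_tokens": buf.n_tokens,           # how much of each is valid
    }

def load(buf, state: dict) -> None:
    buf.input_ids = state["input_ids"].copy()
    buf.scores = state["scores"].copy()
    buf.n_tokens = state["n_tokens"]
```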
def load_state(self, state: LlamaState) -> None:
assert self.ctx is not None
self.eval_tokens = state.eval_tokens.copy()
self.eval_logits = state.eval_logits.copy()
self._scores = state.scores.copy()
self._input_ids = state.input_ids.copy()
self.scores = state.scores.copy()
self.input_ids = state.input_ids.copy()
self.n_tokens = state.n_tokens
state_size = state.llama_state_size
LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
LLamaStateArrayType = llama_cpp.c_uint8 * state_size
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
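The last hunk is an unrelated tidy-up: the parentheses around the ctypes array type were redundant. For reference, the `from_buffer_copy` pattern used here builds a C byte array of exactly `state_size` bytes and copies the saved blob into it, giving `llama_set_state_data()` a stable, writable C buffer:

```python
import ctypes

state_blob = b"\x00\x01\x02\x03"                 # stand-in for state.llama_state
ArrayType = ctypes.c_uint8 * len(state_blob)     # i.e. c_uint8 * state_size
c_buf = ArrayType.from_buffer_copy(state_blob)   # writable C copy of the bytes
assert bytes(c_buf) == state_blob
```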