fix: Make LLamaState pickable for disk cache

I fixed the issue by making the saved state a bytes object instead of the ctypes one which can't be pickled.
This commit is contained in:
Okabintaro 2023-06-13 12:03:31 +02:00
parent ad4479e609
commit 10b0cb727b

View file

@ -141,7 +141,9 @@ class LlamaDiskCache(BaseLlamaCache):
if _key is None:
raise KeyError("Key not found")
value: "LlamaState" = self.cache.pop(_key) # type: ignore
self.cache.push(_key, side="front") # type: ignore
# NOTE: This puts an integer as key in cache, which breaks,
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
# self.cache.push(_key, side="front") # type: ignore
return value
def __contains__(self, key: Sequence[int]) -> bool:
@ -168,7 +170,7 @@ class LlamaState:
eval_logits: Deque[List[float]],
input_ids: npt.NDArray[np.intc],
scores: npt.NDArray[np.single],
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
llama_state: bytes,
llama_state_size: int,
):
self.eval_tokens = eval_tokens
@ -1503,7 +1505,7 @@ class Llama:
eval_logits=self.eval_logits.copy(),
scores=self._scores.copy(),
input_ids=self._input_ids.copy(),
llama_state=llama_state_compact,
llama_state=bytes(llama_state_compact),
llama_state_size=n_bytes,
)
@ -1514,7 +1516,10 @@ class Llama:
self._scores = state.scores.copy()
self._input_ids = state.input_ids.copy()
state_size = state.llama_state_size
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
raise RuntimeError("Failed to set llama state data")
def n_ctx(self) -> int: