diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 182f855..1049e44 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -651,6 +651,45 @@ class Llama:
             llama_cpp.llama_free(self.ctx)
             self.ctx = None
 
+    def __getstate__(self):
+        return dict(
+            verbose=self.verbose,
+            model_path=self.model_path,
+            n_ctx=self.params.n_ctx,
+            n_parts=self.params.n_parts,
+            seed=self.params.seed,
+            f16_kv=self.params.f16_kv,
+            logits_all=self.params.logits_all,
+            vocab_only=self.params.vocab_only,
+            use_mlock=self.params.use_mlock,
+            embedding=self.params.embedding,
+            last_n_tokens_size=self.last_n_tokens_size,
+            last_n_tokens_data=self.last_n_tokens_data,
+            tokens_consumed=self.tokens_consumed,
+            n_batch=self.n_batch,
+            n_threads=self.n_threads,
+        )
+
+    def __setstate__(self, state):
+        self.__init__(
+            model_path=state["model_path"],
+            n_ctx=state["n_ctx"],
+            n_parts=state["n_parts"],
+            seed=state["seed"],
+            f16_kv=state["f16_kv"],
+            logits_all=state["logits_all"],
+            vocab_only=state["vocab_only"],
+            use_mlock=state["use_mlock"],
+            embedding=state["embedding"],
+            n_threads=state["n_threads"],
+            n_batch=state["n_batch"],
+            last_n_tokens_size=state["last_n_tokens_size"],
+            verbose=state["verbose"],
+        )
+        self.last_n_tokens_data = state["last_n_tokens_data"]
+        self.tokens_consumed = state["tokens_consumed"]
+
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 6843ec6..6a50256 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -77,3 +77,20 @@ def test_llama_patch(monkeypatch):
     chunks = llama.create_completion(text, max_tokens=2, stream=True)
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
     assert completion["choices"][0]["finish_reason"] == "length"
+
+
+def test_llama_pickle():
+    import pickle
+    import tempfile
+
+    fp = tempfile.TemporaryFile()
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    pickle.dump(llama, fp)
+    fp.seek(0)
+    llama = pickle.load(fp)
+
+    assert llama
+    assert llama.ctx is not None
+
+    text = b"Hello World"
+
+    assert llama.detokenize(llama.tokenize(text)) == text
\ No newline at end of file
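
For context, the added `__getstate__`/`__setstate__` pair makes a `Llama` instance picklable by saving its constructor arguments and token buffers, and rebuilding the object through `__init__` on unpickling. A minimal usage sketch follows; the model path is a placeholder, and because only the path is stored, the weights file must still exist when the object is unpickled:

```python
import pickle

from llama_cpp import Llama

# Placeholder path; substitute a real ggml model file.
llama = Llama(model_path="./models/ggml-model.bin", vocab_only=True)

# __getstate__ records the constructor arguments plus the token buffers;
# __setstate__ re-runs __init__, reloading the model from model_path.
data = pickle.dumps(llama)
restored = pickle.loads(data)

# The restored instance has a live context and tokenizes as before.
assert restored.detokenize(restored.tokenize(b"Hello World")) == b"Hello World"
```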