From 18a0c10032ef793b67bb8ea9e4ca9e3aaa791595 Mon Sep 17 00:00:00 2001
From: Mug <>
Date: Sat, 29 Apr 2023 12:19:22 +0200
Subject: [PATCH] Remove excessive errors="ignore" and add utf8 test

---
 llama_cpp/llama.py  |  6 +++---
 tests/test_llama.py | 38 ++++++++++++++++++++++++++++++++++++--
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index fe540f9..4e3c3aa 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -358,7 +358,7 @@ class Llama:
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
 
-        tokens = self.tokenize(input.encode("utf-8", errors="ignore"))
+        tokens = self.tokenize(input.encode("utf-8"))
         self.reset()
         self.eval(tokens)
         n_tokens = len(tokens)
@@ -416,7 +416,7 @@ class Llama:
         completion_tokens: List[llama_cpp.llama_token] = []
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
-            b" " + prompt.encode("utf-8", errors="ignore")
+            b" " + prompt.encode("utf-8")
         )
         text: bytes = b""
         returned_characters: int = 0
@@ -431,7 +431,7 @@ class Llama:
         )
 
         if stop != []:
-            stop_sequences = [s.encode("utf-8", errors="ignore") for s in stop]
+            stop_sequences = [s.encode("utf-8") for s in stop]
         else:
             stop_sequences = []
 
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 4dab687..4727d90 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -24,7 +24,7 @@ def test_llama_patch(monkeypatch):
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
 
     output_text = " jumps over the lazy dog."
-    output_tokens = llama.tokenize(output_text.encode("utf-8", errors="ignore"))
+    output_tokens = llama.tokenize(output_text.encode("utf-8"))
     token_eos = llama.token_eos()
     n = 0
 
@@ -93,4 +93,38 @@ def test_llama_pickle():
 
     text = b"Hello World"
 
-    assert llama.detokenize(llama.tokenize(text)) == text
\ No newline at end of file
+    assert llama.detokenize(llama.tokenize(text)) == text
+
+def test_utf8(monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+
+    ## Set up mock function
+    def mock_eval(*args, **kwargs):
+        return 0
+
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+
+    output_text = "😀"
+    output_tokens = llama.tokenize(output_text.encode("utf-8"))
+    token_eos = llama.token_eos()
+    n = 0
+
+    def mock_sample(*args, **kwargs):
+        nonlocal n
+        if n < len(output_tokens):
+            n += 1
+            return output_tokens[n - 1]
+        else:
+            return token_eos
+
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+
+    ## Test basic completion with utf8 multibyte
+    n = 0 # reset
+    completion = llama.create_completion("", max_tokens=4)
+    assert completion["choices"][0]["text"] == output_text
+
+    ## Test basic completion with incomplete utf8 multibyte
+    n = 0 # reset
+    completion = llama.create_completion("", max_tokens=1)
+    assert completion["choices"][0]["text"] == ""
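
A note on the behavior the new test_utf8 pins down: "😀" encodes to four UTF-8
bytes, so with max_tokens=1 generation can stop partway through the character,
and the test expects the completion text to be "" rather than mojibake. Below
is a minimal sketch of that buffering idea using Python's stdlib incremental
decoder; it illustrates the concept only and is an assumption about one way to
get this behavior, not the code path the patch touches.

    import codecs

    # An incremental decoder buffers an incomplete multibyte sequence
    # instead of raising or emitting replacement characters.
    decoder = codecs.getincrementaldecoder("utf-8")()

    emoji = "😀".encode("utf-8")  # four bytes: b"\xf0\x9f\x98\x80"

    # Feed only the first byte, as if sampling stopped after one token
    # of a multi-token character: nothing is decodable yet, so we get "".
    assert decoder.decode(emoji[:1]) == ""

    # Feeding the remaining bytes completes the character.
    assert decoder.decode(emoji[1:]) == "😀"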