Remove excessive errors="ignore" and add utf8 test
This commit is contained in:
parent
b7d14efc8b
commit
18a0c10032
|
@ -358,7 +358,7 @@ class Llama:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
llama_cpp.llama_reset_timings(self.ctx)
|
llama_cpp.llama_reset_timings(self.ctx)
|
||||||
|
|
||||||
tokens = self.tokenize(input.encode("utf-8", errors="ignore"))
|
tokens = self.tokenize(input.encode("utf-8"))
|
||||||
self.reset()
|
self.reset()
|
||||||
self.eval(tokens)
|
self.eval(tokens)
|
||||||
n_tokens = len(tokens)
|
n_tokens = len(tokens)
|
||||||
|
@ -416,7 +416,7 @@ class Llama:
|
||||||
completion_tokens: List[llama_cpp.llama_token] = []
|
completion_tokens: List[llama_cpp.llama_token] = []
|
||||||
# Add blank space to start of prompt to match OG llama tokenizer
|
# Add blank space to start of prompt to match OG llama tokenizer
|
||||||
prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
|
prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
|
||||||
b" " + prompt.encode("utf-8", errors="ignore")
|
b" " + prompt.encode("utf-8")
|
||||||
)
|
)
|
||||||
text: bytes = b""
|
text: bytes = b""
|
||||||
returned_characters: int = 0
|
returned_characters: int = 0
|
||||||
|
@ -431,7 +431,7 @@ class Llama:
|
||||||
)
|
)
|
||||||
|
|
||||||
if stop != []:
|
if stop != []:
|
||||||
stop_sequences = [s.encode("utf-8", errors="ignore") for s in stop]
|
stop_sequences = [s.encode("utf-8") for s in stop]
|
||||||
else:
|
else:
|
||||||
stop_sequences = []
|
stop_sequences = []
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ def test_llama_patch(monkeypatch):
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||||
|
|
||||||
output_text = " jumps over the lazy dog."
|
output_text = " jumps over the lazy dog."
|
||||||
output_tokens = llama.tokenize(output_text.encode("utf-8", errors="ignore"))
|
output_tokens = llama.tokenize(output_text.encode("utf-8"))
|
||||||
token_eos = llama.token_eos()
|
token_eos = llama.token_eos()
|
||||||
n = 0
|
n = 0
|
||||||
|
|
||||||
|
@ -94,3 +94,37 @@ def test_llama_pickle():
|
||||||
text = b"Hello World"
|
text = b"Hello World"
|
||||||
|
|
||||||
assert llama.detokenize(llama.tokenize(text)) == text
|
assert llama.detokenize(llama.tokenize(text)) == text
|
||||||
|
|
||||||
|
def test_utf8(monkeypatch):
    """Regression test: multibyte UTF-8 output must decode correctly, and a
    truncated multibyte sequence must yield an empty string rather than raise.

    NOTE(review): relies on module-level `MODEL` path and `llama_cpp` bindings
    defined elsewhere in this test file — not visible in this diff.
    """
    # vocab_only avoids loading model weights; tokenizer is all we need here.
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)

    ## Set up mock function
    # Stub out the native eval call so no actual inference happens.
    def mock_eval(*args, **kwargs):
        return 0

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)

    # A single emoji encodes to multiple bytes (and typically multiple tokens).
    output_text = "😀"
    output_tokens = llama.tokenize(output_text.encode("utf-8"))
    token_eos = llama.token_eos()
    n = 0

    # Replay the emoji's tokens one at a time, then EOS forever after.
    def mock_sample(*args, **kwargs):
        nonlocal n
        if n < len(output_tokens):
            n += 1
            return output_tokens[n - 1]
        else:
            return token_eos

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)

    ## Test basic completion with utf8 multibyte
    n = 0  # reset
    # max_tokens=4 is enough to emit the full multibyte sequence plus EOS.
    completion = llama.create_completion("", max_tokens=4)
    assert completion["choices"][0]["text"] == output_text

    ## Test basic completion with incomplete utf8 multibyte
    n = 0  # reset
    # Cut off after one token: the partial byte sequence is not valid UTF-8,
    # so the decoded completion text should be empty, not an exception.
    completion = llama.create_completion("", max_tokens=1)
    assert completion["choices"][0]["text"] == ""
|
||||||
|
|
Loading…
Reference in a new issue