Also ignore errors on input prompts

This commit is contained in:
Mug 2023-04-26 14:45:51 +02:00
parent 3c130f00ca
commit 5f81400fcb
3 changed files with 5 additions and 5 deletions

View file

@@ -201,7 +201,7 @@ n_keep = {self.params.n_keep}
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
-        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
+        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
         return _arr[:_n]
     def set_color(self, c):

View file

@@ -358,7 +358,7 @@ class Llama:
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
-        tokens = self.tokenize(input.encode("utf-8"))
+        tokens = self.tokenize(input.encode("utf-8", errors="ignore"))
         self.reset()
         self.eval(tokens)
         n_tokens = len(tokens)
@@ -416,7 +416,7 @@ class Llama:
         completion_tokens: List[llama_cpp.llama_token] = []
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
-            b" " + prompt.encode("utf-8")
+            b" " + prompt.encode("utf-8", errors="ignore")
         )
         text: bytes = b""
         returned_characters: int = 0
@@ -431,7 +431,7 @@ class Llama:
         )
         if stop != []:
-            stop_sequences = [s.encode("utf-8") for s in stop]
+            stop_sequences = [s.encode("utf-8", errors="ignore") for s in stop]
         else:
             stop_sequences = []

View file

@@ -24,7 +24,7 @@ def test_llama_patch(monkeypatch):
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     output_text = " jumps over the lazy dog."
-    output_tokens = llama.tokenize(output_text.encode("utf-8"))
+    output_tokens = llama.tokenize(output_text.encode("utf-8", errors="ignore"))
     token_eos = llama.token_eos()
     n = 0