Bugfix: Stop sequences and missing max_tokens check

Andrei Betlen 2023-04-02 03:59:19 -04:00
parent 42dd11c2b4
commit 4f509b963e

@@ -286,6 +286,7 @@ class Llama:
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
         text = b""
+        returned_characters = 0
         if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
             raise ValueError(
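
In the first hunk, returned_characters is introduced as a running count of how many bytes of the completion have already been yielded to a streaming caller; the surrounding context-window check is unchanged and simply refuses requests where the prompt plus max_tokens cannot fit. A minimal standalone sketch of that check, with made-up values standing in for the library's internals:

n_ctx = 512                       # model context window (illustrative value)
prompt_tokens = list(range(500))  # pretend the prompt tokenized to 500 tokens
max_tokens = 16

if len(prompt_tokens) + max_tokens > n_ctx:
    # 500 + 16 > 512, so the request is rejected before any generation happens
    raise ValueError(f"Requested tokens exceed context window of {n_ctx}")
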
@@ -293,9 +294,9 @@ class Llama:
             )
         if stop != []:
-            stop_bytes = [s.encode("utf-8") for s in stop]
+            stop_sequences = [s.encode("utf-8") for s in stop]
         else:
-            stop_bytes = []
+            stop_sequences = []
         finish_reason = None
         for token in self.generate(
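
The second hunk only renames stop_bytes to stop_sequences. The encoding step itself matters because the detokenized model output is a bytes object, so stop matching has to happen at the byte level rather than on str. Roughly, with illustrative values:

stop = ["\n", "Q:"]
stop_sequences = [s.encode("utf-8") for s in stop]   # [b"\n", b"Q:"]

all_text = b"A: Paris.\nQ: Next question"
print([s for s in stop_sequences if s in all_text])  # both stops match the bytes
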
@@ -306,28 +307,33 @@ class Llama:
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
+                text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
             completion_tokens.append(token)
-            text = self.detokenize(completion_tokens)
-            any_stop = [s for s in stop_bytes if s in text]
+            all_text = self.detokenize(completion_tokens)
+            any_stop = [s for s in stop_sequences if s in all_text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
-                text = text[: text.index(first_stop)]
+                text = all_text[: all_text.index(first_stop)]
                 finish_reason = "stop"
                 break
             if stream:
-                start = len(self.detokenize(completion_tokens[:-1]))
+                start = returned_characters
                 longest = 0
-                # TODO: Clean up this mess
-                for s in stop_bytes:
+                # We want to avoid yielding any characters from
+                # the generated text if they are part of a stop
+                # sequence.
+                for s in stop_sequences:
                     for i in range(len(s), 0, -1):
-                        if s[-i:] == text[-i:]:
+                        if all_text.endswith(s[:i]):
                             if i > longest:
                                 longest = i
                             break
+                text = all_text[: len(all_text) - longest]
+                returned_characters += len(text[start:])
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -335,23 +341,22 @@ class Llama:
                     "model": self.model_path,
                     "choices": [
                         {
-                            "text": text[start : len(text) - longest].decode("utf-8"),
+                            "text": text[start :].decode("utf-8"),
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": None,
                         }
                     ],
                 }
+            if len(completion_tokens) >= max_tokens:
+                text = self.detokenize(completion_tokens)
+                finish_reason = "length"
+                break
         if finish_reason is None:
             finish_reason = "length"
         if stream:
-            if finish_reason == "stop":
-                start = len(self.detokenize(completion_tokens[:-1]))
-                text = text[start:].decode("utf-8")
-            else:
-                text = ""
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -359,7 +364,7 @@ class Llama:
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text,
+                        "text": text[returned_characters:].decode("utf-8"),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,