Fix logprobs for completions and implement logprobs for streaming.

Andrei Betlen 2023-05-19 02:20:27 -04:00
parent a634a2453b
commit 17d4271b04


@@ -710,22 +710,56 @@ class Llama:
# We want to avoid yielding any characters from
# the generated text if they are part of a stop
# sequence.
longest = 0
first_stop_position = 0
for s in stop_sequences:
for i in range(len(s), 0, -1):
if all_text.endswith(s[:i]):
if i > longest:
longest = i
if i > first_stop_position:
first_stop_position = i
break
offset = 0
token_end_position = 0
remaining_tokens = completion_tokens[returned_tokens:]
remaining_length = len(self.detokenize(remaining_tokens))
for token in remaining_tokens:
offset += len(self.detokenize([token]))
# Check if stop sequence is not in the token
if offset >= (remaining_length - longest - 1):
token_end_position += len(self.detokenize([token]))
# Check if stop sequence is in the token
if token_end_position >= (remaining_length - first_stop_position - 1):
break
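The changes above rename `longest` to `first_stop_position` and make the intent explicit: scan every stop sequence for the longest prefix that the generated text currently ends with, then hold back any token whose bytes could still turn out to belong to that stop string. A standalone sketch of the prefix scan (the function name and the toy inputs are illustrative, not part of the library):

def longest_partial_stop(all_text: bytes, stop_sequences) -> int:
    """Length of the longest stop-sequence prefix that all_text currently ends with."""
    first_stop_position = 0
    for s in stop_sequences:
        for i in range(len(s), 0, -1):
            if all_text.endswith(s[:i]):
                if i > first_stop_position:
                    first_stop_position = i
                break
    return first_stop_position

# "\n##" could be the start of the stop string "\n###", so those bytes are held back.
assert longest_partial_stop(b"Answer: 42\n##", [b"\n###"]) == 3
assert longest_partial_stop(b"Answer: 42", [b"\n###"]) == 0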
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
token_str = self.detokenize([token]).decode(
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens
logits = self.eval_logits[token_offset - 1]
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
zip(current_logprobs, range(len(current_logprobs))),
reverse=True,
)
)
top_logprob = {
self.detokenize([llama_cpp.llama_token(i)]).decode(
"utf-8", errors="ignore"
): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: current_logprobs[int(token)]})
logprobs_or_none = {
"tokens": [
self.detokenize([token]).decode(
"utf-8", errors="ignore"
)
],
"text_offset": [text_offset],
"token_logprobs": [sorted_logprobs[int(token)][0]],
"top_logprobs": [top_logprob],
}
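Each streamed chunk can now carry a `logprobs` entry built from the logits of the step that produced the token: convert the logits to log-probabilities, sort them together with their vocabulary ids, keep the top `logprobs` candidates, and make sure the sampled token itself is always present in the map. A self-contained sketch of that computation; `logits_to_logprobs` is reimplemented here as a plain log-softmax standing in for `Llama.logits_to_logprobs`, the small `vocab` list replaces `self.detokenize`, and `token_logprobs_entry` is not a library function:

import math
from typing import Dict, List, Tuple

def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax: logprob_i = logit_i - log(sum_j exp(logit_j)).
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def token_logprobs_entry(
    logits: List[float], sampled_id: int, vocab: List[str], top_n: int
) -> Tuple[float, Dict[str, float]]:
    current_logprobs = logits_to_logprobs(logits)
    sorted_logprobs = sorted(
        zip(current_logprobs, range(len(current_logprobs))), reverse=True
    )
    # Top-N candidates keyed by their decoded text, plus the token actually sampled.
    top_logprob = {vocab[i]: lp for lp, i in sorted_logprobs[:top_n]}
    top_logprob[vocab[sampled_id]] = current_logprobs[sampled_id]
    return current_logprobs[sampled_id], top_logprob

lp, top = token_logprobs_entry([2.0, 1.0, 0.1], sampled_id=1, vocab=["a", "b", "c"], top_n=2)

The `token_offset - 1` index into `self.eval_logits` reflects that each stored row holds the logits produced after evaluating one token, which score the candidates for the next token, so the row that scored the current completion token sits one position before that token's own index.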
returned_tokens += 1
yield {
"id": completion_id,
@@ -738,7 +772,7 @@ class Llama:
"utf-8", errors="ignore"
),
"index": 0,
"logprobs": None,
"logprobs": logprobs_or_none,
"finish_reason": None,
}
],
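The yielded object keeps the OpenAI streaming-completion shape; the only difference after this change is that `logprobs` is populated when the caller requested it. For reference, a single-token chunk would look roughly like this (the id, timestamp, model name, and numbers are invented):

chunk = {
    "id": "cmpl-xxxxxxxx",
    "object": "text_completion",
    "created": 1684464000,
    "model": "./models/ggml-model.bin",
    "choices": [
        {
            "text": " world",
            "index": 0,
            "logprobs": {
                "tokens": [" world"],
                "text_offset": [11],
                "token_logprobs": [-0.42],
                "top_logprobs": [{" world": -0.42, " there": -1.9}],
            },
            "finish_reason": None,
        }
    ],
}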
@@ -766,13 +800,48 @@ class Llama:
else:
end = len(all_text)
offset = 0
token_end_position = 0
for token in remaining_tokens:
offset += len(self.detokenize([token]))
if offset >= end:
token_end_position += len(self.detokenize([token]))
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
token_str = self.detokenize([token]).decode(
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens - 1
logits = self.eval_logits[token_offset]
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
zip(current_logprobs, range(len(current_logprobs))),
reverse=True,
)
)
top_logprob = {
self.detokenize([llama_cpp.llama_token(i)]).decode(
"utf-8", errors="ignore"
): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: current_logprobs[int(token)]})
logprobs_or_none = {
"tokens": [
self.detokenize([token]).decode("utf-8", errors="ignore")
],
"text_offset": [text_offset],
"token_logprobs": [sorted_logprobs[int(token)][0]],
"top_logprobs": [top_logprob],
}
if token_end_position >= end:
last_text = self.detokenize([token])
if offset == end - 1:
if token_end_position == end - 1:
break
returned_tokens += 1
yield {
"id": completion_id,
"object": "text_completion",
@@ -781,10 +850,10 @@ class Llama:
"choices": [
{
"text": last_text[
: len(last_text) - (offset - end)
: len(last_text) - (token_end_position - end)
].decode("utf-8", errors="ignore"),
"index": 0,
"logprobs": None,
"logprobs": logprobs_or_none,
"finish_reason": finish_reason,
}
],
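The slice above trims the final token so that any bytes belonging to the stop sequence (everything past `end`) are cut off before the text is decoded and yielded. A small sketch of that arithmetic using the diff's variable names (the byte strings are made up):

last_text = b" day###"      # detokenized final token
token_end_position = 13     # where this token ends within the remaining text
end = 10                    # where the stop sequence begins
trimmed = last_text[: len(last_text) - (token_end_position - end)]
assert trimmed == b" day"   # the stop-sequence bytes "###" are dropped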
@@ -802,7 +871,7 @@ class Llama:
"utf-8", errors="ignore"
),
"index": 0,
"logprobs": None,
"logprobs": logprobs_or_none,
"finish_reason": finish_reason
if returned_tokens == len(completion_tokens)
else None,
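In this final flush every remaining token is yielded as its own chunk, and only the chunk that delivers the last completion token reports the finish reason; earlier chunks leave it null, matching the OpenAI streaming contract. A tiny sketch of that gating (the helper name is hypothetical):

def chunk_finish_reason(finish_reason, returned_tokens: int, total_tokens: int):
    # Only the chunk that emits the last completion token reports why generation stopped.
    return finish_reason if returned_tokens == total_tokens else None

assert chunk_finish_reason("stop", returned_tokens=3, total_tokens=5) is None
assert chunk_finish_reason("stop", returned_tokens=5, total_tokens=5) == "stop"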
@@ -821,13 +890,19 @@ class Llama:
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
text_offset = 0
text_offset = 0 if echo else len(prompt)
token_offset = 0 if echo else len(prompt_tokens[1:])
text_offsets: List[int] = []
token_logprobs: List[float] = []
token_logprobs: List[Optional[float]] = []
tokens: List[str] = []
top_logprobs: List[Dict[str, float]] = []
top_logprobs: List[Optional[Dict[str, float]]] = []
if echo:
# Remove leading BOS token
all_tokens = prompt_tokens[1:] + completion_tokens
else:
all_tokens = completion_tokens
all_tokens = prompt_tokens + completion_tokens
all_token_strs = [
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
@@ -835,7 +910,7 @@ class Llama:
all_logprobs = [
Llama.logits_to_logprobs(list(map(float, row)))
for row in self.eval_logits
]
][token_offset:]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
):
@@ -848,14 +923,20 @@ class Llama:
)
)
token_logprobs.append(sorted_logprobs[int(token)][0])
top_logprob = {
top_logprob: Optional[Dict[str, float]] = {
self.detokenize([llama_cpp.llama_token(i)]).decode(
"utf-8", errors="ignore"
): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
top_logprob.update({token_str: logprobs_token[int(token)]})
top_logprobs.append(top_logprob)
# Weird idiosyncrasy of the OpenAI API where
# token_logprobs and top_logprobs are null for
# the first token.
if echo and len(all_tokens) > 0:
token_logprobs[0] = None
top_logprobs[0] = None
logprobs_or_none = {
"tokens": tokens,
"text_offset": text_offsets,