Bugfix: Missing logits_to_logprobs

This commit is contained in:
Andrei Betlen 2023-05-04 12:18:40 -04:00
parent d594892fd4
commit 329297fafb

View file

@ -639,7 +639,7 @@ class Llama:
self.detokenize([token]).decode("utf-8", errors="ignore") self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens for token in all_tokens
] ]
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits] all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits]
for token, token_str, logprobs_token in zip( for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs all_tokens, all_token_strs, all_logprobs
): ):
@ -985,7 +985,7 @@ class Llama:
return llama_cpp.llama_token_bos() return llama_cpp.llama_token_bos()
@staticmethod @staticmethod
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]: def logits_to_logprobs(logits: List[float]) -> List[float]:
exps = [math.exp(float(x)) for x in logits] exps = [math.exp(float(x)) for x in logits]
sum_exps = sum(exps) sum_exps = sum(exps)
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps] return [math.log(x / sum_exps) for x in exps]