Fix logprob calculation. Fixes #134

This commit is contained in:
Andrei Betlen 2023-05-01 17:45:08 -04:00
parent c088a2b3a7
commit b6747f722e

View file

@ -638,7 +638,7 @@ class Llama:
for token in all_tokens for token in all_tokens
] ]
all_logprobs = [ all_logprobs = [
[Llama.logit_to_logprob(logit) for logit in row] Llama._logits_to_logprobs(row)
for row in self.eval_logits for row in self.eval_logits
] ]
for token, token_str, logprobs_token in zip( for token, token_str, logprobs_token in zip(
@ -980,5 +980,7 @@ class Llama:
return llama_cpp.llama_token_bos() return llama_cpp.llama_token_bos()
@staticmethod @staticmethod
def logit_to_logprob(x: float) -> float: def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
return math.log(1.0 + math.exp(x)) exps = [math.exp(float(x)) for x in logits]
sum_exps = sum(exps)
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]