From b6747f722e473cb8380a3b8145704e18d8fc76b8 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 1 May 2023 17:45:08 -0400 Subject: [PATCH] Fix logprob calculation. Fixes #134 --- llama_cpp/llama.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index b38f2bb..bec5be7 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -638,7 +638,7 @@ class Llama: for token in all_tokens ] all_logprobs = [ - [Llama.logit_to_logprob(logit) for logit in row] + Llama.logits_to_logprobs(row) for row in self.eval_logits ] for token, token_str, logprobs_token in zip( @@ -980,5 +980,7 @@ class Llama: return llama_cpp.llama_token_bos() @staticmethod - def logit_to_logprob(x: float) -> float: - return math.log(1.0 + math.exp(x)) + def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]: exps = [math.exp(float(x)) for x in logits] sum_exps = sum(exps) return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]