From 329297fafb4916951cf1c3146505a9501e986d95 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Thu, 4 May 2023 12:18:40 -0400
Subject: [PATCH] Bugfix: Missing logits_to_logprobs

---
 llama_cpp/llama.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index fef7b3e..8cd77ee 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -639,7 +639,7 @@ class Llama:
             self.detokenize([token]).decode("utf-8", errors="ignore")
             for token in all_tokens
         ]
-        all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
+        all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits]
         for token, token_str, logprobs_token in zip(
             all_tokens, all_token_strs, all_logprobs
         ):
@@ -985,7 +985,7 @@ class Llama:
         return llama_cpp.llama_token_bos()
 
     @staticmethod
-    def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
+    def logits_to_logprobs(logits: List[float]) -> List[float]:
         exps = [math.exp(float(x)) for x in logits]
         sum_exps = sum(exps)
-        return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
+        return [math.log(x / sum_exps) for x in exps]
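
For reference, a minimal standalone sketch of what the patched `logits_to_logprobs` computes: a softmax over the raw logits followed by the natural log of each probability, returned as plain Python floats (the `llama_cpp.c_float` wrappers are gone). The module-level function name and the example input values below are illustrative only; they are not part of the patch. The call site change mirrors the new signature by coercing each `eval_logits` row with `list(map(float, row))` before passing it in.

```python
import math
from typing import List

def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Same computation as the patched static method:
    # softmax over the logits, then log of each probability.
    exps = [math.exp(float(x)) for x in logits]
    sum_exps = sum(exps)
    return [math.log(x / sum_exps) for x in exps]

# Example: three raw logits for a toy vocabulary of size 3 (values are made up).
logits = [2.0, 1.0, 0.1]
logprobs = logits_to_logprobs(logits)
print(logprobs)  # approximately [-0.417, -1.417, -2.317]

# Sanity check: exponentiating the logprobs recovers probabilities that sum to 1.
print(math.isclose(sum(math.exp(lp) for lp in logprobs), 1.0))  # True
```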