diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index fef7b3e..8cd77ee 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -639,7 +639,7 @@ class Llama: self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens ] - all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits] + all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits] for token, token_str, logprobs_token in zip( all_tokens, all_token_strs, all_logprobs ): @@ -985,7 +985,7 @@ class Llama: return llama_cpp.llama_token_bos() @staticmethod - def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]: + def logits_to_logprobs(logits: List[float]) -> List[float]: exps = [math.exp(float(x)) for x in logits] sum_exps = sum(exps) - return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps] + return [math.log(x / sum_exps) for x in exps]