From a9b9f0397cd86509b3ea359e5260e329464dc032 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jul 2023 01:53:08 -0400 Subject: [PATCH] Format --- llama_cpp/llama.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index b52a398..2537af2 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -224,7 +224,7 @@ class Llama: rope_freq_base: float = 10000.0, rope_freq_scale: float = 1.0, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b - rms_norm_eps: Optional[float] = None, # (TEMPORARY) + rms_norm_eps: Optional[float] = None, # (TEMPORARY) verbose: bool = True, ): """Load a llama.cpp model from `model_path`. @@ -277,7 +277,9 @@ class Llama: if self.tensor_split is not None: FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split) - self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd + self._p_tensor_split = ctypes.POINTER(ctypes.c_float)( + FloatArray + ) # keep a reference to the array so it is not gc'd self.params.tensor_split = self._p_tensor_split self.params.rope_freq_base = rope_freq_base @@ -959,9 +961,7 @@ class Llama: for token in remaining_tokens: token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token - if token_end_position >= ( - remaining_length - first_stop_position - ): + if token_end_position >= (remaining_length - first_stop_position): break logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: