From 4d6b2f7b91a8dfd4b4e283ea73492772b3471afe Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Jan 2024 22:08:27 -0500 Subject: [PATCH] fix: format --- llama_cpp/llama.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 538781d..3d15800 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -197,13 +197,14 @@ class Llama: # kv_overrides is the original python dict self.kv_overrides = kv_overrides if kv_overrides is not None: - # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs - kvo_array_len = len(kv_overrides) + 1 # for sentinel element - self._kv_overrides_array = (llama_cpp.llama_model_kv_override * kvo_array_len)() + kvo_array_len = len(kv_overrides) + 1 # for sentinel element + self._kv_overrides_array = ( + llama_cpp.llama_model_kv_override * kvo_array_len + )() for i, (k, v) in enumerate(kv_overrides.items()): - self._kv_overrides_array[i].key = k.encode('utf-8'); + self._kv_overrides_array[i].key = k.encode("utf-8") if isinstance(v, int): self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT self._kv_overrides_array[i].value.int_value = v @@ -216,7 +217,9 @@ class Llama: else: raise ValueError(f"Unknown value type for {k}: {v}") - self._kv_overrides_array[-1].key = b'\0' # ensure sentinel element is zeroed + self._kv_overrides_array[ + -1 + ].key = b"\0" # ensure sentinel element is zeroed self.model_params.kv_overrides = self._kv_overrides_array self.n_batch = min(n_ctx, n_batch) # ??? @@ -326,7 +329,9 @@ class Llama: (n_ctx, self._n_vocab), dtype=np.single ) - self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context + self._mirostat_mu = ctypes.c_float( + 2.0 * 5.0 + ) # TODO: Move this to sampling context try: self.metadata = self._model.metadata() @@ -334,7 +339,7 @@ class Llama: self.metadata = {} if self.verbose: print(f"Failed to load metadata: {e}", file=sys.stderr) - + if self.verbose: print(f"Model metadata: {self.metadata}", file=sys.stderr) @@ -534,7 +539,7 @@ class Llama: candidates=self._candidates, tau=mirostat_tau, eta=mirostat_eta, - mu=ctypes.pointer(self._mirostat_mu) + mu=ctypes.pointer(self._mirostat_mu), ) else: self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)