fix: format

This commit is contained in:
Andrei Betlen 2024-01-23 22:08:27 -05:00
parent fe5d6ea648
commit 4d6b2f7b91

View file

@@ -197,13 +197,14 @@ class Llama:
         # kv_overrides is the original python dict
         self.kv_overrides = kv_overrides
         if kv_overrides is not None:
             # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
             kvo_array_len = len(kv_overrides) + 1  # for sentinel element
-            self._kv_overrides_array = (llama_cpp.llama_model_kv_override * kvo_array_len)()
+            self._kv_overrides_array = (
+                llama_cpp.llama_model_kv_override * kvo_array_len
+            )()
             for i, (k, v) in enumerate(kv_overrides.items()):
-                self._kv_overrides_array[i].key = k.encode('utf-8');
+                self._kv_overrides_array[i].key = k.encode("utf-8")
                 if isinstance(v, int):
                     self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
                     self._kv_overrides_array[i].value.int_value = v
@@ -216,7 +217,9 @@ class Llama:
                 else:
                     raise ValueError(f"Unknown value type for {k}: {v}")
-            self._kv_overrides_array[-1].key = b'\0' # ensure sentinel element is zeroed
+            self._kv_overrides_array[
+                -1
+            ].key = b"\0"  # ensure sentinel element is zeroed
             self.model_params.kv_overrides = self._kv_overrides_array

         self.n_batch = min(n_ctx, n_batch)  # ???
@@ -326,7 +329,9 @@ class Llama:
             (n_ctx, self._n_vocab), dtype=np.single
         )
-        self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context
+        self._mirostat_mu = ctypes.c_float(
+            2.0 * 5.0
+        )  # TODO: Move this to sampling context

         try:
             self.metadata = self._model.metadata()
@@ -334,7 +339,7 @@ class Llama:
             self.metadata = {}
             if self.verbose:
                 print(f"Failed to load metadata: {e}", file=sys.stderr)

         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
@@ -534,7 +539,7 @@ class Llama:
                 candidates=self._candidates,
                 tau=mirostat_tau,
                 eta=mirostat_eta,
-                mu=ctypes.pointer(self._mirostat_mu)
+                mu=ctypes.pointer(self._mirostat_mu),
             )
         else:
             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)