diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5c66bcf..538781d 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -194,16 +194,16 @@ class Llama: self.model_params.use_mmap = use_mmap if lora_path is None else False self.model_params.use_mlock = use_mlock + # kv_overrides is the original python dict self.kv_overrides = kv_overrides if kv_overrides is not None: - n_overrides = len(kv_overrides) - self._kv_overrides_array = llama_cpp.llama_model_kv_override * (n_overrides + 1) - self._kv_overrides_array_keys = [] - for k, v in kv_overrides.items(): - key_buf = ctypes.create_string_buffer(k.encode("utf-8")) - self._kv_overrides_array_keys.append(key_buf) - self._kv_overrides_array[i].key = key_buf + # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs + kvo_array_len = len(kv_overrides) + 1 # for sentinel element + self._kv_overrides_array = (llama_cpp.llama_model_kv_override * kvo_array_len)() + + for i, (k, v) in enumerate(kv_overrides.items()): + self._kv_overrides_array[i].key = k.encode('utf-8'); if isinstance(v, int): self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT self._kv_overrides_array[i].value.int_value = v @@ -216,10 +216,7 @@ class Llama: else: raise ValueError(f"Unknown value type for {k}: {v}") - self._kv_overrides_array_sentinel_key = b'\0' - - # null array sentinel - self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key + self._kv_overrides_array[-1].key = b'\0' # ensure sentinel element is zeroed self.model_params.kv_overrides = self._kv_overrides_array self.n_batch = min(n_ctx, n_batch) # ???