Implement GGUF metadata KV overrides (#1011)

* Implement GGUF metadata overrides

* whitespace fix

* Fix kv overrides.

* Fix pointer and pickle

* Match llama.cpp kv_overrides cli argument

---------

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Phil H 2024-01-15 17:29:29 +00:00 committed by GitHub
parent 7eff42c239
commit 76aafa6149
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 1 deletions

View file

@ -735,6 +735,7 @@ class Llama:
vocab_only: bool = False,
use_mmap: bool = True,
use_mlock: bool = False,
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
# Context Params
seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
n_ctx: int = 512,
@ -803,6 +804,7 @@ class Llama:
vocab_only: Only load the vocabulary no weights.
use_mmap: Use mmap if possible.
use_mlock: Force the system to keep the model in RAM.
kv_overrides: Key-value overrides for the model.
seed: RNG seed, -1 for random
n_ctx: Text context, 0 = from model
n_batch: Prompt processing maximum batch size
@ -866,6 +868,34 @@ class Llama:
self.model_params.use_mmap = use_mmap if lora_path is None else False
self.model_params.use_mlock = use_mlock
self.kv_overrides = kv_overrides
if kv_overrides is not None:
n_overrides = len(kv_overrides)
self._kv_overrides_array = llama_cpp.llama_model_kv_override * (n_overrides + 1)
self._kv_overrides_array_keys = []
for k, v in kv_overrides.items():
key_buf = ctypes.create_string_buffer(k.encode("utf-8"))
self._kv_overrides_array_keys.append(key_buf)
self._kv_overrides_array[i].key = key_buf
if isinstance(v, int):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
self._kv_overrides_array[i].value.int_value = v
elif isinstance(v, float):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT
self._kv_overrides_array[i].value.float_value = v
elif isinstance(v, bool):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL
self._kv_overrides_array[i].value.bool_value = v
else:
raise ValueError(f"Unknown value type for {k}: {v}")
self._kv_overrides_array_sentinel_key = b'\0'
# null array sentinel
self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key
self.model_params.kv_overrides = self._kv_overrides_array
self.n_batch = min(n_ctx, n_batch) # ???
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
self.n_threads_batch = n_threads_batch or max(
@ -2148,6 +2178,7 @@ class Llama:
vocab_only=self.model_params.vocab_only,
use_mmap=self.model_params.use_mmap,
use_mlock=self.model_params.use_mlock,
kv_overrides=self.kv_overrides,
# Context Params
seed=self.context_params.seed,
n_ctx=self.context_params.n_ctx,
@ -2190,6 +2221,7 @@ class Llama:
vocab_only=state["vocab_only"],
use_mmap=state["use_mmap"],
use_mlock=state["use_mlock"],
kv_overrides=state["kv_overrides"],
# Context Params
seed=state["seed"],
n_ctx=state["n_ctx"],

View file

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Union, List
from typing import Dict, Optional, Union, List
import llama_cpp
@ -71,6 +71,23 @@ class LlamaProxy:
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list)
kv_overrides = {}
for kv in settings.kv_overrides:
key, value = kv.split("=")
if ":" in value:
value_type, value = value.split(":")
if value_type == "bool":
kv_overrides[key] = value.lower() in ["true", "1"]
elif value_type == "int":
kv_overrides[key] = int(value)
elif value_type == "float":
kv_overrides[key] = float(value)
else:
raise ValueError(f"Unknown value type {value_type}")
_model = llama_cpp.Llama(
model_path=settings.model,
@ -81,6 +98,7 @@ class LlamaProxy:
vocab_only=settings.vocab_only,
use_mmap=settings.use_mmap,
use_mlock=settings.use_mlock,
kv_overrides=kv_overrides,
# Context Params
seed=settings.seed,
n_ctx=settings.n_ctx,

View file

@ -48,6 +48,10 @@ class ModelSettings(BaseSettings):
default=llama_cpp.llama_mlock_supported(),
description="Use mlock.",
)
kv_overrides: Optional[List[str]] = Field(
default=None,
description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.",
)
# Context Params
seed: int = Field(
default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."