Update Llama to add params

Author: Andrei Betlen
Date:   2023-03-25 16:26:23 -04:00
parent 4525236214
commit 8ae3beda9c

@@ -13,12 +13,15 @@ class Llama:
     def __init__(
         self,
         model_path: str,
+        # NOTE: The following parameters are likely to change in the future.
         n_ctx: int = 512,
         n_parts: int = -1,
         seed: int = 1337,
         f16_kv: bool = False,
         logits_all: bool = False,
         vocab_only: bool = False,
+        use_mlock: bool = False,
+        embedding: bool = False,
         n_threads: Optional[int] = None,
     ) -> "Llama":
         """Load a llama.cpp model from `model_path`.
@@ -31,6 +34,8 @@ class Llama:
             f16_kv: Use half-precision for key/value cache.
             logits_all: Return logits for all tokens, not just the last token.
             vocab_only: Only load the vocabulary, no weights.
+            use_mlock: Force the system to keep the model in RAM.
+            embedding: Embedding mode only.
             n_threads: Number of threads to use. If None, the number of threads is automatically determined.
 
         Raises:
@@ -51,6 +56,8 @@ class Llama:
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
         self.params.vocab_only = vocab_only
+        self.params.use_mlock = use_mlock
+        self.params.embedding = embedding
         self.n_threads = n_threads or multiprocessing.cpu_count()
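
Below is a minimal usage sketch of the two parameters added by this commit; the model path is a placeholder, and the comments simply restate the docstrings above rather than behavior verified against llama.cpp.

from llama_cpp import Llama

# Hypothetical example: the model path is a placeholder, not a file shipped
# with the library.
llm = Llama(
    model_path="./models/7B/ggml-model.bin",
    use_mlock=True,   # ask the OS to keep the model weights resident in RAM
    embedding=True,   # load in embedding-only mode, per the docstring above
)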