Re-order Llama class params

This commit is contained in:
Andrei Betlen 2023-07-15 15:35:08 -04:00
parent e4f9db37db
commit 8ab098e49d

View file

@ -205,8 +205,6 @@ class Llama:
model_path: str,
# NOTE: These parameters are likely to change in the future.
n_ctx: int = 512,
rope_freq_base: float = 10000.0,
rope_freq_scale: float = 1.0,
n_parts: int = -1,
n_gpu_layers: int = 0,
seed: int = 1337,
@ -223,6 +221,8 @@ class Llama:
lora_path: Optional[str] = None,
low_vram: bool = False,
tensor_split: Optional[List[float]] = None,
rope_freq_base: float = 10000.0,
rope_freq_scale: float = 1.0,
verbose: bool = True,
):
"""Load a llama.cpp model from `model_path`.
@ -230,8 +230,6 @@ class Llama:
Args:
model_path: Path to the model.
n_ctx: Maximum context size.
rope_freq_base: RoPE base frequency.
rope_freq_scale: RoPE frequency scale.
n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
seed: Random seed. -1 for random.
f16_kv: Use half-precision for key/value cache.
@ -246,6 +244,8 @@ class Llama:
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
lora_path: Path to a LoRA file to apply to the model.
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
rope_freq_base: Base frequency for rope sampling.
rope_freq_scale: Scale factor for rope sampling.
verbose: Print verbose output to stderr.
Raises:
@ -260,8 +260,6 @@ class Llama:
self.params = llama_cpp.llama_context_default_params()
self.params.n_ctx = n_ctx
self.params.rope_freq_base = rope_freq_base
self.params.rope_freq_scale = rope_freq_scale
self.params.n_gpu_layers = n_gpu_layers
self.params.seed = seed
self.params.f16_kv = f16_kv
@ -281,6 +279,9 @@ class Llama:
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._c_tensor_split
self.params.rope_freq_base = rope_freq_base
self.params.rope_freq_scale = rope_freq_scale
self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch)