From 453e517fd54c5f2a882199629beb0f01002e0b40 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Tue, 18 Apr 2023 10:20:46 -0400
Subject: [PATCH] Add separate lora_base path for applying LoRA to quantized
 models using original unquantized model weights.

---
 llama_cpp/llama.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 931d0ff..5f09e4d 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -39,6 +39,7 @@ class Llama:
         n_threads: Optional[int] = None,
         n_batch: int = 8,
         last_n_tokens_size: int = 64,
+        lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         verbose: bool = True,
     ):
@@ -58,6 +59,7 @@ class Llama:
             n_threads: Number of threads to use. If None, the number of threads is automatically determined.
             n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
+            lora_base: Optional path to the original unquantized base model; useful when model_path is a quantized model and the LoRA should be applied against the f16 weights.
             lora_path: Path to a LoRA file to apply to the model.
             verbose: Print verbose output to stderr.
 
@@ -110,16 +112,21 @@ class Llama:
             self.model_path.encode("utf-8"), self.params
         )
 
+        self.lora_base = None
         self.lora_path = None
         if lora_path:
+            self.lora_base = lora_base
+            # Use lora_base if set, otherwise fall back to model_path.
+            lora_base = lora_base if lora_base is not None else model_path
+
             self.lora_path = lora_path
             if llama_cpp.llama_apply_lora_from_file(
                 self.ctx,
-                self.lora_path.encode("utf-8"),
-                self.model_path.encode("utf-8"),
+                lora_path.encode("utf-8"),
+                lora_base.encode("utf-8"),
                 llama_cpp.c_int(self.n_threads),
             ):
-                raise RuntimeError(f"Failed to apply LoRA from path: {self.lora_path}")
+                raise RuntimeError(f"Failed to apply LoRA from lora path: {lora_path} to base path: {lora_base}")
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
@@ -815,6 +822,7 @@ class Llama:
             last_n_tokens_size=self.last_n_tokens_size,
             n_batch=self.n_batch,
             n_threads=self.n_threads,
+            lora_base=self.lora_base,
             lora_path=self.lora_path,
         )
 
@@ -833,6 +841,7 @@ class Llama:
             n_threads=state["n_threads"],
             n_batch=state["n_batch"],
             last_n_tokens_size=state["last_n_tokens_size"],
+            lora_base=state["lora_base"],
             lora_path=state["lora_path"],
             verbose=state["verbose"],
        )
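
Usage note (not part of the patch): a minimal sketch of the new parameter,
with hypothetical file paths. When model_path points to a quantized model,
lora_base lets llama_apply_lora_from_file read the tensors it patches from
the original f16 weights rather than the quantized ones; if lora_base is
left unset, the code falls back to model_path as the base.

    from llama_cpp import Llama

    # Paths are illustrative placeholders, not files shipped with the library.
    llm = Llama(
        model_path="./models/7B/ggml-model-q4_0.bin",  # quantized model used for inference
        lora_base="./models/7B/ggml-model-f16.bin",    # unquantized weights the adapter is merged against
        lora_path="./loras/ggml-adapter-model.bin",    # LoRA adapter to apply
    )
    output = llm("Q: Name the planets in the solar system. A:", max_tokens=32)
    print(output["choices"][0]["text"])

Since lora_base is also threaded through __getstate__/__setstate__, a pickled
Llama instance re-applies the adapter against the same base weights when it is
restored.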