diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5f09e4d..ea9f0ff 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -112,21 +112,20 @@ class Llama: self.model_path.encode("utf-8"), self.params ) - self.lora_base = None - self.lora_path = None - if lora_path: - self.lora_base = lora_base - # Use lora_base if set otherwise revert to using model_path. - lora_base = lora_base if lora_base is not None else model_path - - self.lora_path = lora_path + self.lora_base = lora_base + self.lora_path = lora_path + if self.lora_path: if llama_cpp.llama_apply_lora_from_file( self.ctx, - lora_path.encode("utf-8"), - lora_base.encode("utf-8"), + llama_cpp.c_char_p(self.lora_path.encode("utf-8")), + llama_cpp.c_char_p(self.lora_base.encode("utf-8")) + if self.lora_base is not None + else llama_cpp.c_char_p(0), llama_cpp.c_int(self.n_threads), ): - raise RuntimeError(f"Failed to apply LoRA from lora path: {lora_path} to base path: {lora_base}") + raise RuntimeError( + f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}" + ) if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)