diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 54424cb..966f79f 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -229,8 +229,8 @@ class Llama:
         n_batch: int = 512,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
-        rope_freq_base: float = 10000.0,
-        rope_freq_scale: float = 1.0,
+        rope_freq_base: float = 0.0,
+        rope_freq_scale: float = 0.0,
         mul_mat_q: bool = True,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -282,7 +282,6 @@ class Llama:
         Returns:
             A Llama instance.
         """
-
         self.verbose = verbose
 
         self.numa = numa
@@ -320,7 +319,6 @@ class Llama:
         self.n_threads_batch = n_threads_batch or max(
             multiprocessing.cpu_count() // 2, 1
         )
-
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.seed = seed
@@ -328,8 +326,12 @@ class Llama:
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
         self.context_params.n_threads_batch = self.n_threads_batch
-        self.context_params.rope_freq_base = rope_freq_base
-        self.context_params.rope_freq_scale = rope_freq_scale
+        self.context_params.rope_freq_base = (
+            rope_freq_base if rope_freq_base != 0.0 else 0
+        )
+        self.context_params.rope_freq_scale = (
+            rope_freq_scale if rope_freq_scale != 0.0 else 0
+        )
         self.context_params.mul_mat_q = mul_mat_q
         self.context_params.f16_kv = f16_kv
         self.context_params.logits_all = logits_all
@@ -338,7 +340,6 @@ class Llama:
 
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size
-
         self.cache: Optional[BaseLlamaCache] = None
         self.lora_base
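
A minimal usage sketch of the changed defaults (the model path below is hypothetical), assuming llama.cpp's convention that a RoPE value of 0 means "take it from the model metadata":

    from llama_cpp import Llama

    # With the new 0.0 defaults, the RoPE base frequency and scale are left to
    # the values stored in the GGUF metadata instead of being forced to
    # 10000.0 / 1.0 for every model.
    llm = Llama(model_path="./models/model.gguf")  # hypothetical path

    # Passing explicit values still overrides the model metadata as before.
    llm_scaled = Llama(
        model_path="./models/model.gguf",  # hypothetical path
        rope_freq_base=10000.0,
        rope_freq_scale=0.5,  # e.g. linear scaling for a longer context
    )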