Fix rope scaling defaults (#767)

* Fix rope scale with backwards compatibility

* Fix defaults

* Fix op

* Remove backwards compatibility

* Check single val
Josh XT, 2023-09-29 16:03:57 -04:00 (committed by GitHub)
parent a72efc77de
commit a945404b4a
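
The substantive change: the rope_freq_base and rope_freq_scale defaults drop from 10000.0 and 1.0 to 0.0. llama.cpp interprets a zero RoPE frequency base or scale as "use the values recorded in the model file", so models trained with non-default RoPE settings now load correctly without the caller overriding anything. A minimal usage sketch of the behavior after this change (the model path is hypothetical; Llama, n_ctx, rope_freq_base, and rope_freq_scale are the constructor parameters shown in the diff below):

from llama_cpp import Llama

# With the new 0.0 defaults, the RoPE base/scale stored in the model's
# GGUF metadata are used automatically (path is hypothetical):
llm = Llama(model_path="./models/llama-2-7b.Q4_K_M.gguf")

# An explicit non-zero value still overrides the model's settings,
# e.g. linear RoPE scaling to stretch the context window 2x:
llm_long = Llama(
    model_path="./models/llama-2-7b.Q4_K_M.gguf",
    n_ctx=8192,
    rope_freq_scale=0.5,
)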

@@ -229,8 +229,8 @@ class Llama:
         n_batch: int = 512,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
-        rope_freq_base: float = 10000.0,
-        rope_freq_scale: float = 1.0,
+        rope_freq_base: float = 0.0,
+        rope_freq_scale: float = 0.0,
         mul_mat_q: bool = True,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -282,7 +282,6 @@ class Llama:
         Returns:
             A Llama instance.
         """
-
         self.verbose = verbose
         self.numa = numa
@@ -320,7 +319,6 @@
         self.n_threads_batch = n_threads_batch or max(
             multiprocessing.cpu_count() // 2, 1
         )
-
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.seed = seed
@@ -328,8 +326,12 @@
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
         self.context_params.n_threads_batch = self.n_threads_batch
-        self.context_params.rope_freq_base = rope_freq_base
-        self.context_params.rope_freq_scale = rope_freq_scale
+        self.context_params.rope_freq_base = (
+            rope_freq_base if rope_freq_base != 0.0 else 0
+        )
+        self.context_params.rope_freq_scale = (
+            rope_freq_scale if rope_freq_scale != 0.0 else 0
+        )
         self.context_params.mul_mat_q = mul_mat_q
         self.context_params.f16_kv = f16_kv
         self.context_params.logits_all = logits_all
@@ -338,7 +340,6 @@
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size
         self.cache: Optional[BaseLlamaCache] = None
-
         self.lora_base = lora_base
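
A note on the guarded assignments in the fourth hunk: because the integer 0 and the float 0.0 compare equal in Python, the expression rope_freq_base if rope_freq_base != 0.0 else 0 always yields a value equal to rope_freq_base, and the underlying ctypes c_float fields in the bindings store either form identically. A small standalone check of that reasoning (runnable without llama-cpp-python):

import ctypes

# For any float x, (x if x != 0.0 else 0) compares equal to x,
# since 0 == 0.0 in Python.
for x in (0.0, 1.0, 10000.0):
    assert (x if x != 0.0 else 0) == x

# ctypes coerces int 0 and float 0.0 to the same c_float value,
# so the context-params fields end up identical either way.
assert ctypes.c_float(0).value == ctypes.c_float(0.0).value == 0.0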