commit f96de6d920
Author: Andrei Betlen
Date:   2024-04-03 00:55:21 -04:00

@@ -730,12 +730,14 @@ class _LlamaSamplingContext:
         if len(self.prev) > 0:
             nl_token = ctx_main.model.token_nl()
             nl_logit = logits_array[nl_token]
-            if self.params.penalty_last_n > 0:
+            last_tokens = self.prev[-self.params.penalty_last_n:]
+            last_tokens_size = min(len(last_tokens), self.params.penalty_last_n)
+            if last_tokens_size > 0:
+                last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
                 ctx_main.sample_repetition_penalties(
                     token_data_array,
-                    # TODO: Only create this once
-                    (llama_cpp.llama_token * len(self.prev))(*self.prev),
-                    self.params.penalty_last_n,
+                    last_tokens_p,
+                    last_tokens_size,
                     self.params.penalty_repeat,
                     self.params.penalty_freq,
                     self.params.penalty_present,
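
For reference, below is a minimal standalone sketch of the window-slicing logic the added lines implement. The helper name build_penalty_window is hypothetical, and ctypes.c_int32 stands in for llama_cpp.llama_token (assumed to be a 32-bit integer type) so the snippet runs without llama-cpp-python installed; the real code packs the window the same way before handing it to ctx_main.sample_repetition_penalties.

    # Standalone sketch of the penalty-window logic added by this commit.
    # Assumption: ctypes.c_int32 stands in for llama_cpp.llama_token.
    import ctypes

    llama_token = ctypes.c_int32


    def build_penalty_window(prev, penalty_last_n):
        # Take at most the last `penalty_last_n` tokens of the history.
        last_tokens = prev[-penalty_last_n:]
        last_tokens_size = min(len(last_tokens), penalty_last_n)
        if last_tokens_size <= 0:
            return None, 0
        # Pack the window into a C array for the native sampler call.
        last_tokens_p = (llama_token * len(last_tokens))(*last_tokens)
        return last_tokens_p, last_tokens_size


    # Short history: the whole history is penalized.
    arr, size = build_penalty_window([1, 2, 3], penalty_last_n=64)
    assert size == 3 and list(arr) == [1, 2, 3]

    # Long history: only the most recent 64 tokens are penalized.
    arr, size = build_penalty_window(list(range(100)), penalty_last_n=64)
    assert size == 64 and list(arr) == list(range(36, 100))

    # penalty_last_n == 0 disables the penalty window entirely.
    arr, size = build_penalty_window([1, 2, 3], penalty_last_n=0)
    assert arr is None and size == 0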