fix: correct passing of last tokens to the sample_repetition_penalties function (#1295)

Co-authored-by: ymikhaylov <ymikhaylov@x5.ru>
Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Yuri Mikhailov 2024-04-02 04:25:43 +09:00 committed by GitHub
parent 45bf5ae582
commit 62aad610e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -730,12 +730,14 @@ class _LlamaSamplingContext:
if len(self.prev) > 0:
nl_token = ctx_main.model.token_nl()
nl_logit = logits_array[nl_token]
if self.params.penalty_last_n > 0:
last_tokens = self.prev[-self.params.penalty_last_n:]
last_tokens_size = min(len(last_tokens), self.params.penalty_last_n)
if last_tokens_size > 0:
last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
ctx_main.sample_repetition_penalties(
token_data_array,
# TODO: Only create this once
(llama_cpp.llama_token * len(self.prev))(*self.prev),
self.params.penalty_last_n,
last_tokens_p,
last_tokens_size,
self.params.penalty_repeat,
self.params.penalty_freq,
self.params.penalty_present,