diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 1e78221..f2e1383 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1019,12 +1019,11 @@ class Llama:
         """
         assert self._ctx.ctx is not None
         assert self._batch.batch is not None
-        n_ctx = self._n_ctx
+        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), self.n_tokens)
+            n_past = self.n_tokens
             n_tokens = len(batch)
-            self._ctx.kv_cache_seq_rm(-1, n_past, -1)
             self._batch.set_batch(
                 batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
             )