diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index f222dfd..4840caf 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -68,7 +68,7 @@ class Llama: maxlen=self.last_n_tokens_size, ) self.tokens_consumed = 0 - self.n_batch = n_batch + self.n_batch = min(n_ctx, n_batch) self.n_threads = n_threads or multiprocessing.cpu_count()