Bugfix: n_batch should always be <= n_ctx

Andrei Betlen 2023-04-04 13:08:21 -04:00
parent 248b0566fa
commit 5075c16fcc

@@ -68,7 +68,7 @@ class Llama:
             maxlen=self.last_n_tokens_size,
         )
         self.tokens_consumed = 0
-        self.n_batch = n_batch
+        self.n_batch = min(n_ctx, n_batch)
         self.n_threads = n_threads or multiprocessing.cpu_count()
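For illustration, a minimal sketch of the clamping behaviour this change introduces (the helper name clamp_batch and the example values are hypothetical, not part of the commit):

    # Sketch of the fix: never let the batch size exceed the context size,
    # so evaluation never submits more tokens per batch than fit in n_ctx.
    def clamp_batch(n_ctx: int, n_batch: int) -> int:
        return min(n_ctx, n_batch)

    # A caller passing n_batch larger than n_ctx now gets it clamped;
    # smaller values are left unchanged.
    assert clamp_batch(256, 512) == 256
    assert clamp_batch(2048, 512) == 512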