diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d5cf401..15307ab 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -923,6 +923,12 @@ class Llama: self._model = _LlamaModel( path_model=self.model_path, params=self.model_params, verbose=self.verbose ) + # Set the default value for the context and correct the batch + if n_ctx == 0: + n_ctx = self._model.n_ctx_train() + self.n_batch = min(n_ctx, n_batch) + self.context_params.n_ctx = self._model.n_ctx_train() + self.context_params.n_batch = self.n_batch self._ctx = _LlamaContext( model=self._model,