diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index fdde7ea..434d824 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -308,6 +308,8 @@ class Llama: self.tensor_split = tensor_split self._p_tensor_split = None if self.tensor_split is not None: + if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES: + raise ValueError(f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}") # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES self._c_tensor_split = FloatArray(