Add validation for tensor_split size exceeding LLAMA_MAX_DEVICES (#820)

* Add validation for tensor_split size exceeding LLAMA_MAX_DEVICES

* reword
This commit is contained in:
Eric Liu 2023-10-15 10:51:51 -07:00 committed by GitHub
parent f30aa20126
commit b50166500e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -308,6 +308,8 @@ class Llama:
self.tensor_split = tensor_split
self._p_tensor_split = None
if self.tensor_split is not None:
if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
raise ValueError(f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}")
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
self._c_tensor_split = FloatArray(