Merge pull request #523 from shouyiwang/tensor_split

Update tensor_split to match llama.cpp's change
This commit is contained in:
Andrei 2023-07-26 13:53:02 -04:00 committed by GitHub
commit e665b557fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -273,13 +273,12 @@ class Llama:
self.params.low_vram = low_vram
self.tensor_split = tensor_split
self._c_tensor_split = None
self._p_tensor_split = None
if self.tensor_split is not None:
# Convert the type and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._c_tensor_split
FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._p_tensor_split
self.params.rope_freq_base = rope_freq_base
self.params.rope_freq_scale = rope_freq_scale