Change tensor_split from array to pointer

This commit is contained in:
Shouyi Wang 2023-07-25 18:29:59 +10:00
parent c7c700b0d4
commit 426dbfe3f4

View file

@ -273,13 +273,12 @@ class Llama:
self.params.low_vram = low_vram
self.tensor_split = tensor_split
self._c_tensor_split = None
self._p_tensor_split = None
if self.tensor_split is not None:
#Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._c_tensor_split
FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._p_tensor_split
self.params.rope_freq_base = rope_freq_base
self.params.rope_freq_scale = rope_freq_scale