diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 11d0ad4..a1634fa 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -79,6 +79,10 @@ c_size_t_p = POINTER(c_size_t) # llama.h bindings +GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas") +GGML_CUDA_MAX_DEVICES = ctypes.c_int(16) +LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else ctypes.c_int(1) + # #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt' LLAMA_FILE_MAGIC_GGJT = ctypes.c_uint(0x67676A74) # #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' @@ -142,9 +146,12 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # struct llama_context_params { -# int n_ctx; // text context -# int n_gpu_layers; // number of layers to store in VRAM -# int seed; // RNG seed, -1 for random +# int n_ctx; // text context +# int n_batch; // prompt processing batch size +# int n_gpu_layers; // number of layers to store in VRAM +# int main_gpu; // the GPU that is used for scratch and small tensors +# float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs +# int seed; // RNG seed, -1 for random # bool f16_kv; // use fp16 for KV cache # bool logits_all; // the llama_eval() call computes all logits, not just the last one @@ -153,7 +160,6 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # bool use_mlock; // force system to keep model in RAM # bool embedding; // embedding mode only - # // called with a progress value between 0 and 1, pass NULL to disable # llama_progress_callback progress_callback; # // context pointer passed to the progress callback @@ -162,7 +168,10 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) class llama_context_params(Structure): _fields_ = [ ("n_ctx", c_int), + ("n_batch", c_int), ("n_gpu_layers", c_int), + ("main_gpu", c_int), + ("tensor_split", c_float * LLAMA_MAX_DEVICES.value), ("seed", c_int), ("f16_kv", c_bool), ( diff --git a/vendor/llama.cpp b/vendor/llama.cpp index f4c55d3..2d7bf11 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit f4c55d3bd7e124b101bc974cbbf0e0dbbc32d5a3 +Subproject commit 2d7bf110edd8c49209401a16132052cba706ffd0