From a40476e299e9468f4be69a075c028b005a5812f8 Mon Sep 17 00:00:00 2001 From: MillionthOdin16 <102247808+MillionthOdin16@users.noreply.github.com> Date: Sun, 2 Apr 2023 21:50:13 -0400 Subject: [PATCH] Update llama_cpp.py Make shared library code more robust with some platform specific functionality and more descriptive errors when failures occur --- llama_cpp/llama_cpp.py | 50 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 51d237b..c185336 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1,17 +1,49 @@ +import sys +import os import ctypes - from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t - import pathlib -from itertools import chain # Load the library -# TODO: fragile, should fix -_base_path = pathlib.Path(__file__).parent -(_lib_path,) = chain( - _base_path.glob("*.so"), _base_path.glob("*.dylib"), _base_path.glob("*.dll") -) -_lib = ctypes.CDLL(str(_lib_path)) +def load_shared_library(lib_base_name): + # Determine the file extension based on the platform + if sys.platform.startswith("linux"): + lib_ext = ".so" + elif sys.platform == "darwin": + lib_ext = ".dylib" + elif sys.platform == "win32": + lib_ext = ".dll" + else: + raise RuntimeError("Unsupported platform") + + # Construct the paths to the possible shared library names + _base_path = pathlib.Path(__file__).parent.resolve() + # Searching for the library in the current directory under the name "libllama" (default name + # for llamacpp) and "llama" (default name for this repo) + _lib_paths = [ + _base_path / f"lib{lib_base_name}{lib_ext}", + _base_path / f"{lib_base_name}{lib_ext}" + ] + + # Add the library directory to the DLL search path on Windows (if needed) + if sys.platform == "win32" and sys.version_info >= (3, 8): + os.add_dll_directory(str(_base_path)) + + # Try to load the shared library, handling potential errors + for _lib_path in _lib_paths: + if _lib_path.exists(): + try: + return ctypes.CDLL(str(_lib_path)) + except Exception as e: + raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") + + raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found") + +# Specify the base name of the shared library to load +lib_base_name = "llama" + +# Load the library +_lib = load_shared_library(lib_base_name) # C types llama_context_p = c_void_p