diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 51d237b..c185336 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1,17 +1,49 @@ +import sys +import os import ctypes - from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t - import pathlib -from itertools import chain # Load the library -# TODO: fragile, should fix -_base_path = pathlib.Path(__file__).parent -(_lib_path,) = chain( - _base_path.glob("*.so"), _base_path.glob("*.dylib"), _base_path.glob("*.dll") -) -_lib = ctypes.CDLL(str(_lib_path)) +def load_shared_library(lib_base_name): + # Determine the file extension based on the platform + if sys.platform.startswith("linux"): + lib_ext = ".so" + elif sys.platform == "darwin": + lib_ext = ".dylib" + elif sys.platform == "win32": + lib_ext = ".dll" + else: + raise RuntimeError("Unsupported platform") + + # Construct the paths to the possible shared library names + _base_path = pathlib.Path(__file__).parent.resolve() + # Searching for the library in the current directory under the name "libllama" (default name + # for llamacpp) and "llama" (default name for this repo) + _lib_paths = [ + _base_path / f"lib{lib_base_name}{lib_ext}", + _base_path / f"{lib_base_name}{lib_ext}" + ] + + # Add the library directory to the DLL search path on Windows (if needed) + if sys.platform == "win32" and sys.version_info >= (3, 8): + os.add_dll_directory(str(_base_path)) + + # Try to load the shared library, handling potential errors + for _lib_path in _lib_paths: + if _lib_path.exists(): + try: + return ctypes.CDLL(str(_lib_path)) + except Exception as e: + raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") + + raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found") + +# Specify the base name of the shared library to load +lib_base_name = "llama" + +# Load the library +_lib = load_shared_library(lib_base_name) # C types llama_context_p = c_void_p