diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index a1634fa..bb9b0e5 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -15,28 +15,32 @@ from ctypes import ( c_size_t, ) import pathlib +from typing import List # Load the library def _load_shared_library(lib_base_name: str): - # Determine the file extension based on the platform - if sys.platform.startswith("linux"): - lib_ext = ".so" - elif sys.platform == "darwin": - lib_ext = ".so" - elif sys.platform == "win32": - lib_ext = ".dll" - else: - raise RuntimeError("Unsupported platform") - # Construct the paths to the possible shared library names _base_path = pathlib.Path(__file__).parent.resolve() # Searching for the library in the current directory under the name "libllama" (default name # for llamacpp) and "llama" (default name for this repo) - _lib_paths = [ - _base_path / f"lib{lib_base_name}{lib_ext}", - _base_path / f"{lib_base_name}{lib_ext}", - ] + _lib_paths: List[pathlib.Path] = [] + # Determine the file extension based on the platform + if sys.platform.startswith("linux"): + _lib_paths += [ + _base_path / f"lib{lib_base_name}.so", + ] + elif sys.platform == "darwin": + _lib_paths += [ + _base_path / f"lib{lib_base_name}.so", + _base_path / f"lib{lib_base_name}.dylib", + ] + elif sys.platform == "win32": + _lib_paths += [ + _base_path / f"{lib_base_name}.dll", + ] + else: + raise RuntimeError("Unsupported platform") if "LLAMA_CPP_LIB" in os.environ: lib_base_name = os.environ["LLAMA_CPP_LIB"] @@ -160,6 +164,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # bool use_mlock; // force system to keep model in RAM # bool embedding; // embedding mode only + # // called with a progress value between 0 and 1, pass NULL to disable # llama_progress_callback progress_callback; # // context pointer passed to the progress callback