diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index f79a2c2..1c73c0a 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -288,7 +288,7 @@ class Llama: if self.tensor_split is not None: # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES - FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value + FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES self._c_tensor_split = FloatArray( *tensor_split ) # keep a reference to the array so it is not gc'd diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index db104e7..7c1b9f4 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -23,17 +23,12 @@ Then visit http://localhost:8000/docs to see the interactive API docs. """ import os import argparse -from typing import Literal, Union +from typing import List, Literal, Union import uvicorn from llama_cpp.server.app import create_app, Settings -def get_non_none_base_types(annotation): - if not hasattr(annotation, "__args__"): - return annotation - return [arg for arg in annotation.__args__ if arg is not type(None)][0] - def get_base_type(annotation): if getattr(annotation, '__origin__', None) is Literal: return type(annotation.__args__[0]) @@ -41,9 +36,22 @@ def get_base_type(annotation): non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)] if non_optional_args: return get_base_type(non_optional_args[0]) + elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List: + return get_base_type(annotation.__args__[0]) else: return annotation +def contains_list_type(annotation) -> bool: + origin = getattr(annotation, '__origin__', None) + + if origin is list or origin is List: + return True + elif origin in (Literal, Union): + return any(contains_list_type(arg) for arg in annotation.__args__) + else: + return False + + if __name__ == "__main__": parser = argparse.ArgumentParser() for name, field in Settings.model_fields.items(): @@ -53,6 +61,7 @@ if __name__ == "__main__": parser.add_argument( f"--{name}", dest=name, + nargs="*" if contains_list_type(field.annotation) else None, type=get_base_type(field.annotation) if field.annotation is not None else str, help=description, )