Fix tensor_split cli option

This commit is contained in:
Andrei Betlen 2023-09-13 20:00:42 -04:00
parent 203ede4ba2
commit c4c440ba2d
2 changed files with 16 additions and 7 deletions

View file

@ -288,7 +288,7 @@ class Llama:
if self.tensor_split is not None:
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
self._c_tensor_split = FloatArray(
*tensor_split
) # keep a reference to the array so it is not gc'd

View file

@ -23,17 +23,12 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import os
import argparse
from typing import Literal, Union
from typing import List, Literal, Union
import uvicorn
from llama_cpp.server.app import create_app, Settings
def get_non_none_base_types(annotation):
if not hasattr(annotation, "__args__"):
return annotation
return [arg for arg in annotation.__args__ if arg is not type(None)][0]
def get_base_type(annotation):
if getattr(annotation, '__origin__', None) is Literal:
return type(annotation.__args__[0])
@ -41,9 +36,22 @@ def get_base_type(annotation):
non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)]
if non_optional_args:
return get_base_type(non_optional_args[0])
elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List:
return get_base_type(annotation.__args__[0])
else:
return annotation
def contains_list_type(annotation) -> bool:
origin = getattr(annotation, '__origin__', None)
if origin is list or origin is List:
return True
elif origin in (Literal, Union):
return any(contains_list_type(arg) for arg in annotation.__args__)
else:
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser()
for name, field in Settings.model_fields.items():
@ -53,6 +61,7 @@ if __name__ == "__main__":
parser.add_argument(
f"--{name}",
dest=name,
nargs="*" if contains_list_type(field.annotation) else None,
type=get_base_type(field.annotation) if field.annotation is not None else str,
help=description,
)