Fix tensor_split cli option

This commit is contained in:
Andrei Betlen 2023-09-13 20:00:42 -04:00
parent 203ede4ba2
commit c4c440ba2d
2 changed files with 16 additions and 7 deletions

View file

@ -288,7 +288,7 @@ class Llama:
if self.tensor_split is not None: if self.tensor_split is not None:
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
self._c_tensor_split = FloatArray( self._c_tensor_split = FloatArray(
*tensor_split *tensor_split
) # keep a reference to the array so it is not gc'd ) # keep a reference to the array so it is not gc'd

View file

@ -23,17 +23,12 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
""" """
import os import os
import argparse import argparse
from typing import Literal, Union from typing import List, Literal, Union
import uvicorn import uvicorn
from llama_cpp.server.app import create_app, Settings from llama_cpp.server.app import create_app, Settings
def get_non_none_base_types(annotation):
if not hasattr(annotation, "__args__"):
return annotation
return [arg for arg in annotation.__args__ if arg is not type(None)][0]
def get_base_type(annotation): def get_base_type(annotation):
if getattr(annotation, '__origin__', None) is Literal: if getattr(annotation, '__origin__', None) is Literal:
return type(annotation.__args__[0]) return type(annotation.__args__[0])
@ -41,9 +36,22 @@ def get_base_type(annotation):
non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)] non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)]
if non_optional_args: if non_optional_args:
return get_base_type(non_optional_args[0]) return get_base_type(non_optional_args[0])
elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List:
return get_base_type(annotation.__args__[0])
else: else:
return annotation return annotation
def contains_list_type(annotation) -> bool:
origin = getattr(annotation, '__origin__', None)
if origin is list or origin is List:
return True
elif origin in (Literal, Union):
return any(contains_list_type(arg) for arg in annotation.__args__)
else:
return False
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
for name, field in Settings.model_fields.items(): for name, field in Settings.model_fields.items():
@ -53,6 +61,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
f"--{name}", f"--{name}",
dest=name, dest=name,
nargs="*" if contains_list_type(field.annotation) else None,
type=get_base_type(field.annotation) if field.annotation is not None else str, type=get_base_type(field.annotation) if field.annotation is not None else str,
help=description, help=description,
) )