This commit is contained in:
Andrei Betlen 2023-09-13 21:23:23 -04:00
parent 2920c4bf7e
commit 4daf77e546

View file

@ -144,10 +144,8 @@ class ErrorResponseFormatters:
@staticmethod @staticmethod
def context_length_exceeded( def context_length_exceeded(
request: Union[ request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
"CreateCompletionRequest", "CreateChatCompletionRequest" match, # type: Match[str] # type: ignore
],
match, # type: Match[str] # type: ignore
) -> Tuple[int, ErrorResponse]: ) -> Tuple[int, ErrorResponse]:
"""Formatter for context length exceeded error""" """Formatter for context length exceeded error"""
@ -184,10 +182,8 @@ class ErrorResponseFormatters:
@staticmethod @staticmethod
def model_not_found( def model_not_found(
request: Union[ request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
"CreateCompletionRequest", "CreateChatCompletionRequest" match, # type: Match[str] # type: ignore
],
match # type: Match[str] # type: ignore
) -> Tuple[int, ErrorResponse]: ) -> Tuple[int, ErrorResponse]:
"""Formatter for model_not_found error""" """Formatter for model_not_found error"""
@ -315,12 +311,7 @@ def create_app(settings: Optional[Settings] = None):
settings = Settings() settings = Settings()
middleware = [ middleware = [
Middleware( Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
RawContextMiddleware,
plugins=(
plugins.RequestIdPlugin(),
)
)
] ]
app = FastAPI( app = FastAPI(
middleware=middleware, middleware=middleware,
@ -426,12 +417,13 @@ async def get_event_publisher(
except anyio.get_cancelled_exc_class() as e: except anyio.get_cancelled_exc_class() as e:
print("disconnected") print("disconnected")
with anyio.move_on_after(1, shield=True): with anyio.move_on_after(1, shield=True):
print( print(f"Disconnected from client (via refresh/close) {request.client}")
f"Disconnected from client (via refresh/close) {request.client}"
)
raise e raise e
model_field = Field(description="The model to use for generating completions.", default=None)
model_field = Field(
description="The model to use for generating completions.", default=None
)
max_tokens_field = Field( max_tokens_field = Field(
default=16, ge=1, description="The maximum number of tokens to generate." default=16, ge=1, description="The maximum number of tokens to generate."
@ -625,9 +617,9 @@ async def create_completion(
] ]
) )
iterator_or_completion: Union[llama_cpp.Completion, Iterator[ iterator_or_completion: Union[
llama_cpp.CompletionChunk llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
]] = await run_in_threadpool(llama, **kwargs) ] = await run_in_threadpool(llama, **kwargs)
if isinstance(iterator_or_completion, Iterator): if isinstance(iterator_or_completion, Iterator):
# EAFP: It's easier to ask for forgiveness than permission # EAFP: It's easier to ask for forgiveness than permission
@ -641,12 +633,13 @@ async def create_completion(
send_chan, recv_chan = anyio.create_memory_object_stream(10) send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse( return EventSourceResponse(
recv_chan, data_sender_callable=partial( # type: ignore recv_chan,
data_sender_callable=partial( # type: ignore
get_event_publisher, get_event_publisher,
request=request, request=request,
inner_send_chan=send_chan, inner_send_chan=send_chan,
iterator=iterator(), iterator=iterator(),
) ),
) )
else: else:
return iterator_or_completion return iterator_or_completion
@ -762,9 +755,9 @@ async def create_chat_completion(
] ]
) )
iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[ iterator_or_completion: Union[
llama_cpp.ChatCompletionChunk llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
]] = await run_in_threadpool(llama.create_chat_completion, **kwargs) ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
if isinstance(iterator_or_completion, Iterator): if isinstance(iterator_or_completion, Iterator):
# EAFP: It's easier to ask for forgiveness than permission # EAFP: It's easier to ask for forgiveness than permission
@ -778,12 +771,13 @@ async def create_chat_completion(
send_chan, recv_chan = anyio.create_memory_object_stream(10) send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse( return EventSourceResponse(
recv_chan, data_sender_callable=partial( # type: ignore recv_chan,
data_sender_callable=partial( # type: ignore
get_event_publisher, get_event_publisher,
request=request, request=request,
inner_send_chan=send_chan, inner_send_chan=send_chan,
iterator=iterator(), iterator=iterator(),
) ),
) )
else: else:
return iterator_or_completion return iterator_or_completion