commit 4daf77e546
parent 2920c4bf7e
Author: Andrei Betlen
Date:   2023-09-13 21:23:23 -04:00

@@ -144,10 +144,8 @@ class ErrorResponseFormatters:
     @staticmethod
     def context_length_exceeded(
-        request: Union[
-            "CreateCompletionRequest", "CreateChatCompletionRequest"
-        ],
-        match,  # type: Match[str] # type: ignore
+        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+        match,  # type: Match[str] # type: ignore
     ) -> Tuple[int, ErrorResponse]:
         """Formatter for context length exceeded error"""
@@ -184,10 +182,8 @@ class ErrorResponseFormatters:
     @staticmethod
     def model_not_found(
-        request: Union[
-            "CreateCompletionRequest", "CreateChatCompletionRequest"
-        ],
-        match  # type: Match[str] # type: ignore
+        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+        match,  # type: Match[str] # type: ignore
     ) -> Tuple[int, ErrorResponse]:
         """Formatter for model_not_found error"""
@@ -315,12 +311,7 @@ def create_app(settings: Optional[Settings] = None):
         settings = Settings()
 
     middleware = [
-        Middleware(
-            RawContextMiddleware,
-            plugins=(
-                plugins.RequestIdPlugin(),
-            )
-        )
+        Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
     ]
     app = FastAPI(
         middleware=middleware,
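For context on the collapsed Middleware(...) line: RawContextMiddleware and RequestIdPlugin come from starlette-context and expose a per-request context dict inside handlers. A minimal sketch of the same wiring on its own, assuming starlette-context is installed (the /ping route is illustrative and not part of this server):

from fastapi import FastAPI
from starlette.middleware import Middleware
from starlette_context import context, plugins
from starlette_context.middleware import RawContextMiddleware

middleware = [
    Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
]
app = FastAPI(middleware=middleware)

@app.get("/ping")
async def ping():
    # context.data is the per-request dict populated by the plugins; with
    # RequestIdPlugin it carries the generated or propagated request id.
    return {"request_context": context.data}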
@@ -426,12 +417,13 @@ async def get_event_publisher(
         except anyio.get_cancelled_exc_class() as e:
             print("disconnected")
             with anyio.move_on_after(1, shield=True):
-                print(
-                    f"Disconnected from client (via refresh/close) {request.client}"
-                )
+                print(f"Disconnected from client (via refresh/close) {request.client}")
                 raise e
 
 
-model_field = Field(description="The model to use for generating completions.", default=None)
+model_field = Field(
+    description="The model to use for generating completions.", default=None
+)
 
 max_tokens_field = Field(
     default=16, ge=1, description="The maximum number of tokens to generate."
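The model_field/max_tokens_field values are plain pydantic Field(...) descriptors that the request models reuse, so the descriptions and validation constraints live in one place. A minimal sketch of that reuse; the tiny CreateCompletionRequest below omits everything except these two fields and is only illustrative:

from typing import Optional
from pydantic import BaseModel, Field

model_field = Field(
    description="The model to use for generating completions.", default=None
)
max_tokens_field = Field(
    default=16, ge=1, description="The maximum number of tokens to generate."
)

class CreateCompletionRequest(BaseModel):
    # Shared Field(...) descriptors keep metadata consistent across request models.
    model: Optional[str] = model_field
    max_tokens: int = max_tokens_field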
@@ -625,9 +617,9 @@ async def create_completion(
             ]
         )
 
-    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
-        llama_cpp.CompletionChunk
-    ]] = await run_in_threadpool(llama, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
+    ] = await run_in_threadpool(llama, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):
         # EAFP: It's easier to ask for forgiveness than permission
@@ -641,12 +633,13 @@ async def create_completion(
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(  # type: ignore
+            recv_chan,
+            data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-            )
+            ),
         )
     else:
         return iterator_or_completion
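The streaming branch above pairs an anyio memory channel with sse-starlette's data_sender_callable: the response consumes recv_chan while get_event_publisher pumps chunks into send_chan. A self-contained sketch of that pattern, assuming sse-starlette, anyio, and FastAPI are installed; fake_token_iterator, event_publisher, and the /stream route are stand-ins for the llama call, get_event_publisher, and the real endpoint:

import json
from functools import partial
from typing import Dict, Iterator

import anyio
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

def fake_token_iterator() -> Iterator[Dict[str, str]]:
    # Stand-in for the blocking iterator returned by llama(**kwargs) when streaming.
    for token in ("Hello", ",", " world"):
        yield {"token": token}

async def event_publisher(request: Request, inner_send_chan, iterator):
    # Plays the role of get_event_publisher: forward chunks as SSE data events
    # and stop early if the client has disconnected.
    async with inner_send_chan:
        for chunk in iterator:
            await inner_send_chan.send(dict(data=json.dumps(chunk)))
            if await request.is_disconnected():
                break
        await inner_send_chan.send(dict(data="[DONE]"))

@app.get("/stream")
async def stream(request: Request):
    # run_in_threadpool keeps the (potentially blocking) model call off the event loop.
    iterator = await run_in_threadpool(fake_token_iterator)
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    return EventSourceResponse(
        recv_chan,
        data_sender_callable=partial(
            event_publisher,
            request=request,
            inner_send_chan=send_chan,
            iterator=iterator,
        ),
    )

The sender callable runs concurrently with the response, so the channel acts as a small buffer between the model's thread and the event loop that writes SSE frames.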
@@ -762,9 +755,9 @@ async def create_chat_completion(
             ]
         )
 
-    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
-        llama_cpp.ChatCompletionChunk
-    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
+    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):
         # EAFP: It's easier to ask for forgiveness than permission
@@ -778,12 +771,13 @@ async def create_chat_completion(
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(  # type: ignore
+            recv_chan,
+            data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-            )
+            ),
        )
     else:
         return iterator_or_completion
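Both endpoints share the "# EAFP" comment: the first chunk is pulled inside the request handler so a failure surfaces as an ordinary HTTP error instead of breaking the SSE stream after headers are sent. A hedged sketch of that idea in isolation; make_chunks and first_then_stream are illustrative names, while first_response and iterator follow the surrounding code:

from typing import Iterator

from fastapi.concurrency import run_in_threadpool

def make_chunks() -> Iterator[str]:
    # Illustrative stand-in for a streaming completion call such as llama(**kwargs).
    yield from ("one", "two", "three")

async def first_then_stream(chunks: Iterator[str]) -> Iterator[str]:
    # EAFP: eagerly pull the first chunk; if it raises, the exception propagates
    # here and can still be turned into a normal error response.
    first_response = await run_in_threadpool(next, chunks)

    def iterator() -> Iterator[str]:
        # Re-yield the chunk we already consumed, then the rest of the stream.
        yield first_response
        yield from chunks

    return iterator()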