misc: Format

Andrei Betlen 2024-02-28 14:27:40 -05:00
parent 0d37ce52b1
commit 727d60c28a
5 changed files with 44 additions and 39 deletions

View file

@@ -199,8 +199,8 @@ async def authenticate(
 @router.post(
     "/v1/completions",
     summary="Completion",
-    dependencies=[Depends(authenticate)],
-    response_model= Union[
+    dependencies=[Depends(authenticate)],
+    response_model=Union[
         llama_cpp.CreateCompletionResponse,
         str,
     ],
@@ -211,19 +211,19 @@ async def authenticate(
                 "application/json": {
                     "schema": {
                         "anyOf": [
-                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
+                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
                         ],
                         "title": "Completion response, when stream=False",
                     }
                 },
-                "text/event-stream":{
-                    "schema": {
-                        "type": "string",
-                        "title": "Server Side Streaming response, when stream=True. " +
-                        "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
-                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
+                "text/event-stream": {
+                    "schema": {
+                        "type": "string",
+                        "title": "Server Side Streaming response, when stream=True. "
+                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
+                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                     }
-                }
+                },
             },
         }
     },
@@ -290,7 +290,7 @@ async def create_completion(
                 inner_send_chan=send_chan,
                 iterator=iterator(),
             ),
-            sep='\n',
+            sep="\n",
         )
     else:
         return iterator_or_completion
@@ -310,10 +310,10 @@ async def create_embedding(
 @router.post(
-    "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)],
-    response_model= Union[
-        llama_cpp.ChatCompletion, str
-    ],
+    "/v1/chat/completions",
+    summary="Chat",
+    dependencies=[Depends(authenticate)],
+    response_model=Union[llama_cpp.ChatCompletion, str],
     responses={
         "200": {
             "description": "Successful Response",
@@ -321,19 +321,21 @@ async def create_embedding(
                 "application/json": {
                     "schema": {
                         "anyOf": [
-                            {"$ref": "#/components/schemas/CreateChatCompletionResponse"}
+                            {
+                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
+                            }
                         ],
                         "title": "Completion response, when stream=False",
                     }
                 },
-                "text/event-stream":{
-                    "schema": {
-                        "type": "string",
-                        "title": "Server Side Streaming response, when stream=True" +
-                        "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
-                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
+                "text/event-stream": {
+                    "schema": {
+                        "type": "string",
+                        "title": "Server Side Streaming response, when stream=True"
+                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
+                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                     }
-                }
+                },
             },
         }
     },
@@ -383,7 +385,7 @@ async def create_chat_completion(
                 inner_send_chan=send_chan,
                 iterator=iterator(),
             ),
-            sep='\n',
+            sep="\n",
         )
     else:
         return iterator_or_completion
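
The text/event-stream branches above describe Server-Sent Events: when stream=True, each chunk arrives as a "data: {...}" line and the stream ends with "data: [DONE]", exactly as the schema examples note. A minimal sketch of a client consuming that stream follows; the local URL, the request body, and the httpx dependency are assumptions for illustration, not part of this commit.

    # Hypothetical client sketch; server address and payload values are invented.
    import json
    import httpx

    with httpx.stream(
        "POST",
        "http://localhost:8000/v1/completions",
        json={"prompt": "Hello", "max_tokens": 16, "stream": True},
        timeout=None,
    ) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separator lines between events
            payload = line[len("data: "):]
            if payload.strip() == "[DONE]":
                break  # end-of-stream sentinel, as in the schema example
            chunk = json.loads(payload)  # one CreateCompletionResponse-shaped chunk
            print(chunk["choices"][0]["text"], end="", flush=True)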

View file

@@ -22,6 +22,7 @@ from llama_cpp.server.types import (
     CreateChatCompletionRequest,
 )
 
 
 class ErrorResponse(TypedDict):
     """OpenAI style error response"""
@@ -75,7 +76,7 @@ class ErrorResponseFormatters:
                 (completion_tokens or 0) + prompt_tokens,
                 prompt_tokens,
                 completion_tokens,
-            ), # type: ignore
+            ),  # type: ignore
             type="invalid_request_error",
             param="messages",
             code="context_length_exceeded",
@@ -207,4 +208,3 @@ class RouteErrorHandler(APIRoute):
             )
 
         return custom_route_handler
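
For context, ErrorResponse above is the OpenAI-style error body, and the context_length_exceeded formatter in the second hunk fills its message from the context window and token counts. A rough sketch of the fields such a response carries; the wording and numbers below are invented for illustration, not taken from this diff.

    # Illustrative only: message text and token counts are made up.
    import json

    error_body = {
        "message": "This model's maximum context length is 2048 tokens. "
        "However, you requested 2112 tokens (2000 in the messages, 112 in the completion). "
        "Please reduce the length of the messages or completion.",
        "type": "invalid_request_error",
        "param": "messages",
        "code": "context_length_exceeded",
    }
    print(json.dumps(error_body, indent=2))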

View file

@@ -88,15 +88,15 @@ class LlamaProxy:
             assert (
                 settings.hf_tokenizer_config_path is not None
             ), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
-            chat_handler = (
-                llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
-                    json.load(open(settings.hf_tokenizer_config_path))
-                )
+            chat_handler = llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
+                json.load(open(settings.hf_tokenizer_config_path))
             )
 
         tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
         if settings.hf_pretrained_model_name_or_path is not None:
-            tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(settings.hf_pretrained_model_name_or_path)
+            tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
+                settings.hf_pretrained_model_name_or_path
+            )
 
         draft_model = None
         if settings.draft_model is not None:
@@ -120,17 +120,20 @@ class LlamaProxy:
                         kv_overrides[key] = float(value)
                     else:
                         raise ValueError(f"Unknown value type {value_type}")
 
         import functools
 
         kwargs = {}
 
         if settings.hf_model_repo_id is not None:
-            create_fn = functools.partial(llama_cpp.Llama.from_pretrained, repo_id=settings.hf_model_repo_id, filename=settings.model)
+            create_fn = functools.partial(
+                llama_cpp.Llama.from_pretrained,
+                repo_id=settings.hf_model_repo_id,
+                filename=settings.model,
+            )
         else:
             create_fn = llama_cpp.Llama
             kwargs["model_path"] = settings.model
 
         _model = create_fn(
             **kwargs,
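
The second hunk above reformats, without changing, the functools.partial pattern: when a Hugging Face repo is configured, the remote-loading arguments are pre-bound so that the single call site _model = create_fn(**kwargs, ...) serves both code paths. A reduced, self-contained sketch of that pattern follows; the Loader class and build helper are stand-ins for illustration, not llama_cpp APIs.

    # Illustrative only: Loader and build are invented names, not part of llama_cpp.
    import functools
    from typing import Optional


    class Loader:
        def __init__(self, model_path: Optional[str] = None, n_ctx: int = 2048):
            self.model_path, self.n_ctx = model_path, n_ctx

        @classmethod
        def from_pretrained(cls, repo_id: str, filename: str, **kwargs) -> "Loader":
            # In the real code this path would download the file first.
            return cls(model_path=f"{repo_id}/{filename}", **kwargs)


    def build(repo_id: Optional[str], model: str) -> Loader:
        kwargs = {}
        if repo_id is not None:
            # Pre-bind the remote arguments; the call site below stays unchanged.
            create_fn = functools.partial(
                Loader.from_pretrained, repo_id=repo_id, filename=model
            )
        else:
            create_fn = Loader
            kwargs["model_path"] = model
        return create_fn(**kwargs, n_ctx=2048)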

View file

@@ -74,7 +74,9 @@ class ModelSettings(BaseSettings):
         ge=0,
         description="The number of threads to use when batch processing.",
     )
-    rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED)
+    rope_scaling_type: int = Field(
+        default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
+    )
     rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
     rope_freq_scale: float = Field(
         default=0.0, description="RoPE frequency scaling factor"
@@ -193,6 +195,4 @@ class Settings(ServerSettings, ModelSettings):
 class ConfigFileSettings(ServerSettings):
     """Configuration file format settings."""
 
-    models: List[ModelSettings] = Field(
-        default=[], description="Model configs"
-    )
+    models: List[ModelSettings] = Field(default=[], description="Model configs")
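
ConfigFileSettings gathers per-model ModelSettings entries under one models list. A sketch of populating such a pydantic model from a JSON document follows; the import path, the file contents, and the use of pydantic v2's model_validate_json are assumptions for illustration, and the server's actual config loading is not part of this diff.

    # Hypothetical usage; module path and JSON values are assumptions.
    import json

    from llama_cpp.server.settings import ConfigFileSettings

    raw = json.dumps(
        {
            "host": "127.0.0.1",
            "port": 8000,
            "models": [
                {"model": "models/llama-2-7b.Q4_K_M.gguf", "n_ctx": 4096},
                {"model": "models/mistral-7b-instruct.Q4_K_M.gguf"},
            ],
        }
    )

    config = ConfigFileSettings.model_validate_json(raw)  # pydantic v2 parsing
    for model_settings in config.models:
        print(model_settings.model)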

View file

@@ -110,7 +110,7 @@ class CreateCompletionRequest(BaseModel):
         default=None,
         description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
     )
-    max_tokens: Optional[int] = Field(
+    max_tokens: Optional[int] = Field(
         default=16, ge=0, description="The maximum number of tokens to generate."
     )
     temperature: float = temperature_field