diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py
index b770a01..bfc7342 100644
--- a/llama_cpp/llama_types.py
+++ b/llama_cpp/llama_types.py
@@ -60,6 +60,8 @@ class Completion(TypedDict):
 class ChatCompletionMessage(TypedDict):
     role: Literal["assistant", "user", "system"]
     content: str
+    user: NotRequired[str]
+
 
 class ChatCompletionChoice(TypedDict):
     index: int
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 0b7b1b2..ba2ca2f 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -185,7 +185,13 @@ def create_completion(
     completion_or_chunks = llama(
         **request.dict(
             exclude={
-                "model"
+                "model",
+                "n",
+                "frequency_penalty",
+                "presence_penalty",
+                "best_of",
+                "logit_bias",
+                "user",
             }
         )
     )
@@ -221,7 +227,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"model"}))
+    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
 
 
 class ChatCompletionRequestMessage(BaseModel):
@@ -283,7 +289,12 @@ def create_chat_completion(
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
             exclude={
-                "model"
+                "model",
+                "n",
+                "presence_penalty",
+                "frequency_penalty",
+                "logit_bias",
+                "user",
             }
         ),
     )
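
Taken together, these exclusions let the server accept the full OpenAI-style request schema while stripping the parameters the underlying Llama object does not implement ("n", "frequency_penalty", "presence_penalty", "best_of", "logit_bias", "user") before the call. A minimal sketch of the resulting behavior, assuming a local server on port 8000 and a placeholder model name (neither is part of the diff):

# Sketch only: URL, port, and model name below are assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "local-model",  # stripped by exclude={...} before the llama call
        "messages": [{"role": "user", "content": "Hello!"}],
        # With this patch, OpenAI-compatible fields that llama.cpp does not
        # support still validate and are silently dropped server-side:
        "n": 1,
        "logit_bias": {},
        "user": "example-user",
    },
)
print(resp.json()["choices"][0]["message"]["content"])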