fix: pass correct type to chat handlers for chat completion logprobs

Andrei Betlen 2024-04-10 03:41:55 -04:00
parent 060bfa64d5
commit bb65b4d764
2 changed files with 18 additions and 9 deletions
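With this change, Llama.create_chat_completion forwards logprobs as an OpenAI-style boolean and top_logprobs as the per-token alternative count, and the chat completion handlers convert the pair back into the integer logprobs value expected by create_completion. A minimal sketch of a call that would exercise the fixed path (the model path, prompt, and settings below are placeholders, not part of this commit):

from llama_cpp import Llama

# Placeholder model path; any chat-capable GGUF model would do here.
llm = Llama(model_path="./models/model.gguf", n_ctx=2048)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello in one word."}],
    logprobs=True,    # OpenAI-style boolean toggle, as passed through after this fix
    top_logprobs=5,   # number of alternatives requested per token
    max_tokens=16,
)

# After this fix the choice-level "logprobs" field is populated from the
# underlying completion instead of always being None.
print(response["choices"][0]["logprobs"])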

llama_cpp/llama.py

@@ -1664,7 +1664,8 @@ class Llama:
 top_k=top_k,
 min_p=min_p,
 typical_p=typical_p,
-logprobs=top_logprobs if logprobs else None,
+logprobs=logprobs,
+top_logprobs=top_logprobs,
 stream=stream,
 stop=stop,
 seed=seed,

llama_cpp/llama_chat_format.py

@@ -77,6 +77,8 @@ class LlamaChatCompletionHandler(Protocol):
 mirostat_eta: float = 0.1,
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
@@ -338,7 +340,7 @@ def _convert_completion_to_chat_function(
 }
 ],
 },
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "finish_reason": "tool_calls",
 }
 ],
@@ -391,7 +393,7 @@ def _convert_completion_to_chat_function(
 {
 "index": 0,
 "finish_reason": None,
-"logprobs": None,
+"logprobs": chunk["choices"][0]["logprobs"],
 "delta": {
 "role": None,
 "content": None,
@@ -426,7 +428,7 @@ def _convert_completion_to_chat_function(
 {
 "index": 0,
 "finish_reason": None,
-"logprobs": None,
+"logprobs": chunk["choices"][0]["logprobs"],
 "delta": {
 "role": None,
 "content": None,
@@ -491,7 +493,6 @@ def chat_formatter_to_chat_completion_handler(
 temperature: float = 0.2,
 top_p: float = 0.95,
 top_k: int = 40,
-logprobs: int = 0,
 min_p: float = 0.05,
 typical_p: float = 1.0,
 stream: bool = False,
@@ -512,6 +513,8 @@ def chat_formatter_to_chat_completion_handler(
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
 logit_bias: Optional[Dict[str, float]] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
@@ -581,7 +584,7 @@ def chat_formatter_to_chat_completion_handler(
 top_k=top_k,
 min_p=min_p,
 typical_p=typical_p,
-logprobs=logprobs,
+logprobs=top_logprobs if logprobs else None,
 stream=stream,
 stop=stop,
 seed=seed,
@@ -1628,7 +1631,7 @@ def functionary_chat_handler(
 }
 ],
 },
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "finish_reason": "tool_calls",
 }
 ],
@@ -2085,7 +2088,7 @@ def functionary_v1_v2_chat_handler(
 choices=[
 {
 "index": 0,
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "message": {
 "role": "assistant",
 "content": None if content == "" else content,
@@ -2311,11 +2314,14 @@ def chatml_function_calling(
 model: Optional[str] = None,
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
 Iterator[llama_types.CreateChatCompletionStreamResponse],
 ]:
+print(logprobs)
 function_calling_template = (
 "{% for message in messages %}"
 "<|im_start|>{{ message.role }}\n"
@@ -2437,6 +2443,7 @@ def chatml_function_calling(
 model=model,
 logits_processor=logits_processor,
 grammar=grammar,
+logprobs=top_logprobs if logprobs else None,
 ),
 stream=stream,
 )
@@ -2549,6 +2556,7 @@ def chatml_function_calling(
 typical_p=typical_p,
 stream=stream,
 stop=["<|im_end|>"],
+logprobs=top_logprobs if logprobs else None,
 max_tokens=None,
 presence_penalty=presence_penalty,
 frequency_penalty=frequency_penalty,
@@ -2660,7 +2668,7 @@ def chatml_function_calling(
 {
 "finish_reason": "tool_calls",
 "index": 0,
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "message": {
 "role": "assistant",
 "content": None,