fix: missing logprobs in response, incorrect response type for functionary, minor type issues. Closes #1328 Closes #1314

Andrei Betlen 2024-04-05 10:50:49 -04:00
parent 9111b6e03a
commit 1ae3abbcc3


@@ -6,7 +6,7 @@ import ctypes
import dataclasses
import random
import string
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
import jinja2
@@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
}
],
},
"logprobs": None,
"finish_reason": "tool_calls",
}
],
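The added "logprobs": None entries are the "missing logprobs in response" part of the fix: an OpenAI-compatible chat completion choice is expected to carry a logprobs key even when log probabilities were not requested. A minimal sketch of the intended choice shape (build_choice is a hypothetical helper for illustration, not part of llama-cpp-python):

    from typing import Any, Dict, Optional


    def build_choice(
        message: Dict[str, Any],
        finish_reason: str,
        logprobs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        # One OpenAI-style chat completion "choice". The "logprobs" key is
        # always present; it is simply None when log probabilities were not
        # requested, which is the shape the patch restores above.
        return {
            "index": 0,
            "logprobs": logprobs,
            "message": message,
            "finish_reason": finish_reason,
        }


    choice = build_choice({"role": "assistant", "content": None}, "tool_calls")
    assert "logprobs" in choice and choice["logprobs"] is None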
@@ -1191,7 +1192,6 @@ def format_mistral_instruct(
elif (
message["role"] == "assistant"
and message["content"] is not None
-and isinstance(message["content"], str)
):
prompt += " [/INST]" + message["content"] + eos
prompt += " [/INST]"
@@ -1263,7 +1263,7 @@ def format_gemma(
**kwargs: Any,
) -> ChatFormatterResponse:
system_message = _get_system_message(messages)
-if system_message is not None and system_message != "":
+if system_message != "":
logger.debug(
"`role='system'` messages are not allowed on Google's Gemma models."
)
@@ -1628,6 +1628,7 @@ def functionary_chat_handler(
}
],
},
"logprobs": None,
"finish_reason": "tool_calls",
}
],
@@ -1909,14 +1910,14 @@ def functionary_v1_v2_chat_handler(
return grammar
def create_completion(stop):
-completion: llama_types.Completion = llama.create_completion(
+completion = cast(llama_types.Completion, llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
-stream=stream,
+stream=False,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
@@ -1929,7 +1930,7 @@ def functionary_v1_v2_chat_handler(
model=model,
logits_processor=logits_processor,
grammar=grammar,
-)
+))
return completion
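Pinning stream=False and wrapping the call in cast(llama_types.Completion, ...) is the "incorrect response type for functionary" part of the fix: create_completion is annotated to return either a finished completion or an iterator of streamed chunks depending on stream, and this code path only ever consumes the finished dict. A rough standalone sketch of the same narrowing pattern, assuming a generic API whose return type depends on a stream flag (generate below is hypothetical; only the use of typing.cast mirrors the patch):

    from typing import Dict, Iterator, Union, cast

    CompletionDict = Dict[str, object]


    def generate(prompt: str, stream: bool) -> Union[CompletionDict, Iterator[CompletionDict]]:
        # Stand-in for an API whose return type depends on the `stream` flag.
        if stream:
            return iter([{"text": prompt}])
        return {"text": prompt}


    def generate_blocking(prompt: str) -> CompletionDict:
        # stream=False guarantees a plain dict at runtime; cast() only tells
        # the type checker so and has no runtime effect.
        return cast(CompletionDict, generate(prompt, stream=False))


    print(generate_blocking("hello")["text"])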
@@ -2050,7 +2051,7 @@ def functionary_v1_v2_chat_handler(
assert "usage" in completion
assert len(function_calls) == len(function_bodies)
-tool_calls = []
+tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
for function_call, function_body in zip(function_calls, function_bodies):
tool_calls.append(
{
@@ -2070,6 +2071,12 @@ def functionary_v1_v2_chat_handler(
)
# TODO: support stream mode
+function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
+"function_call": {
+"name": tool_calls[0]["function"]["name"],
+"arguments": tool_calls[0]["function"]["arguments"],
+}
+} if len(tool_calls) == 1 else {}
return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"],
object="chat.completion",
@@ -2078,14 +2085,12 @@ def functionary_v1_v2_chat_handler(
choices=[
{
"index": 0,
"logprobs": None,
"message": {
"role": "assistant",
"content": None if content == "" else content,
"function_call": {
"name": tool_calls[0]["function"]["name"],
"arguments": tool_calls[0]["function"]["arguments"],
} if len(tool_calls) > 0 else None,
"tool_calls": tool_calls if len(tool_calls) > 0 else None,
"tool_calls": tool_calls,
**function_call_dict,
},
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
}
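The new function_call_dict keeps the deprecated function_call field out of the assistant message unless there is exactly one tool call, instead of emitting it with a None value. A simplified sketch of that conditional dictionary-unpacking pattern, using plain dicts and made-up tool data rather than the llama_types TypedDicts:

    from typing import Any, Dict, List

    tool_calls: List[Dict[str, Any]] = [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        },
    ]

    # Emit the legacy `function_call` field only in the single-call case,
    # mirroring the `... if len(tool_calls) == 1 else {}` guard in the patch.
    function_call_dict: Dict[str, Any] = (
        {
            "function_call": {
                "name": tool_calls[0]["function"]["name"],
                "arguments": tool_calls[0]["function"]["arguments"],
            }
        }
        if len(tool_calls) == 1
        else {}
    )

    message = {
        "role": "assistant",
        "content": None,
        "tool_calls": tool_calls,
        **function_call_dict,  # absent entirely when there is not exactly one call
    }
    print(message)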
@@ -2565,8 +2570,8 @@ def chatml_function_calling(
tool_name = text[len("functions.") :]
tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
if not stream:
-completions = []
-completions_tool_name = []
+completions: List[llama_types.CreateCompletionResponse] = []
+completions_tool_name: List[str] = []
while tool is not None:
prompt += f"functions.{tool_name}:\n"
try:
@@ -2603,6 +2608,7 @@ def chatml_function_calling(
logits_processor=logits_processor,
grammar=grammar,
)
+completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
completions.append(completion_or_chunks)
completions_tool_name.append(tool_name)
prompt += completion_or_chunks["choices"][0]["text"]
@@ -2631,6 +2637,7 @@ def chatml_function_calling(
follow_up_gbnf_tool_grammar, verbose=llama.verbose
),
)
+response = cast(llama_types.CreateCompletionResponse, response)
tool_name = response["choices"][0]["text"][len("functions.") :]
tool = next(
@@ -2638,7 +2645,7 @@ def chatml_function_calling(
)
# Merge completions
-function_call = {
+function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
"function_call": {
"name": tool_name,
"arguments": completions[0]["choices"][0]["text"],
@@ -2653,6 +2660,7 @@ def chatml_function_calling(
{
"finish_reason": "tool_calls",
"index": 0,
"logprobs": None,
"message": {
"role": "assistant",
"content": None,
@@ -2673,20 +2681,22 @@ def chatml_function_calling(
zip(completions_tool_name, completions)
)
],
-**function_call
+**function_call_dict
},
}
],
"usage": {
"completion_tokens": sum(
completion["usage"]["completion_tokens"]
completion["usage"]["completion_tokens"] if "usage" in completion else 0
for completion in completions
),
"prompt_tokens": sum(
completion["usage"]["prompt_tokens"] for completion in completions
completion["usage"]["prompt_tokens"] if "usage" in completion else 0
for completion in completions
),
"total_tokens": sum(
completion["usage"]["total_tokens"] for completion in completions
completion["usage"]["total_tokens"] if "usage" in completion else 0
for completion in completions
),
},
}
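The `if "usage" in completion else 0` guards make the merged usage totals tolerate completions whose usage block is absent, since that field can be missing from a completion response. A small self-contained sketch of the same defensive summation (the sample data is made up):

    from typing import Any, Dict, List

    completions: List[Dict[str, Any]] = [
        {"usage": {"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17}},
        {},  # a completion with no usage block at all
    ]

    usage = {
        key: sum(c["usage"][key] if "usage" in c else 0 for c in completions)
        for key in ("prompt_tokens", "completion_tokens", "total_tokens")
    }
    print(usage)  # -> {'prompt_tokens': 12, 'completion_tokens': 5, 'total_tokens': 17}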