feat: Add tools/functions variables to Jinja2ChatFormatter, add function response formatting for all simple chat formats (#1273)

* Add tools/functions variables to Jinja2ChatFormatter

Also fixed the missing tools/tool_choice parameters in chat_formatter_to_chat_completion_handler().

* Set grammar when doing explicit function calling

* Add function/tool response formatting for all simple chat formats

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Andrei, 2024-03-19 04:55:57 -04:00, committed by GitHub
parent 18d7ce918f
commit 60d8498f21


@@ -188,6 +188,10 @@ class Jinja2ChatFormatter(ChatFormatter):
self,
*,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
**kwargs: Any,
) -> ChatFormatterResponse:
def raise_exception(message: str):
@@ -199,6 +203,10 @@ class Jinja2ChatFormatter(ChatFormatter):
bos_token=self.bos_token,
raise_exception=raise_exception,
add_generation_prompt=self.add_generation_prompt,
functions=functions,
function_call=function_call,
tools=tools,
tool_choice=tool_choice,
)
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
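With tools and tool_choice now exposed to the template, a Jinja2 chat template can render tool definitions directly into the prompt. A minimal sketch, assuming a hypothetical template string and tool schema (neither is part of this commit):

from llama_cpp.llama_chat_format import Jinja2ChatFormatter

# Illustrative template: lists available tools before the conversation.
template = (
    "{% if tools %}Available tools:\n"
    "{% for tool in tools %}- {{ tool.function.name }}: {{ tool.function.description }}\n"
    "{% endfor %}{% endif %}"
    "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)

formatter = Jinja2ChatFormatter(template=template, eos_token="</s>", bos_token="<s>")
result = formatter(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
)
print(result.prompt)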
@@ -288,6 +296,183 @@ def _convert_completion_to_chat(
return _convert_text_completion_to_chat(completion)

def _convert_completion_to_chat_function(
tool_name: str,
completion_or_chunks: Union[
llama_types.CreateCompletionResponse,
Iterator[llama_types.CreateCompletionStreamResponse],
],
stream: bool,
):
if not stream:
completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore
assert "usage" in completion
tool_id = "call_" + "_0_" + tool_name + "_" + completion["id"]
# TODO: Fix for legacy function calls
chat_completion: llama_types.CreateChatCompletionResponse = {
"id": "chat" + completion["id"],
"object": "chat.completion",
"created": completion["created"],
"model": completion["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": None,
"function_call": {
"name": tool_name,
"arguments": completion["choices"][0]["text"],
},
"tool_calls": [
{
"id": tool_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": completion["choices"][0]["text"],
},
}
],
},
"finish_reason": "tool_calls",
}
],
"usage": completion["usage"],
}
return chat_completion
else:
chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks  # type: ignore

def _stream_response_to_function_stream(
chunks: Iterator[llama_types.CreateCompletionStreamResponse],
) -> Iterator[llama_types.CreateChatCompletionStreamResponse]:
# blank first message
first = True
id_ = None
created = None
model = None
tool_id = None
for chunk in chunks:
if first:
id_ = "chat" + chunk["id"]
created = chunk["created"]
model = chunk["model"]
tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"]
yield {
"id": id_,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"finish_reason": None,
"logprobs": None,
"delta": {
"role": "assistant",
"content": None,
"function_call": None,
"tool_calls": None,
},
}
],
}
yield {
"id": "chat" + chunk["id"],
"object": "chat.completion.chunk",
"created": chunk["created"],
"model": chunk["model"],
"choices": [
{
"index": 0,
"finish_reason": None,
"logprobs": None,
"delta": {
"role": None,
"content": None,
"function_call": {
"name": tool_name,
"arguments": chunk["choices"][0]["text"],
},
"tool_calls": [
{
"index": 0,
"id": tool_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": "",
},
}
],
},
}
],
}
first = False
continue
assert tool_id is not None
yield {
"id": "chat" + chunk["id"],
"object": "chat.completion.chunk",
"created": chunk["created"],
"model": chunk["model"],
"choices": [
{
"index": 0,
"finish_reason": None,
"logprobs": None,
"delta": {
"role": None,
"content": None,
"function_call": {
"name": tool_name,
"arguments": chunk["choices"][0]["text"],
},
"tool_calls": [
{
"index": 0,
"id": tool_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": chunk["choices"][0][
"text"
],
},
}
],
},
}
],
}
if id_ is not None and created is not None and model is not None:
yield {
"id": id_,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"finish_reason": "tool_calls",
"logprobs": None,
"delta": {
"role": None,
"content": None,
"function_call": None,
"tool_calls": None,
},
}
],
}
return _stream_response_to_function_stream(chunks)
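For reference, the non-streaming branch above produces a response whose id is the completion id prefixed with "chat", and whose tool call id follows the "call_" + "_0_" + tool_name + "_" + id pattern. With tool_name "get_weather" and completion id "cmpl-abc123" (both values illustrative), the converted message looks like:

{
    "id": "chatcmpl-abc123",
    "object": "chat.completion",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": None,
                "function_call": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
                "tool_calls": [
                    {
                        "id": "call__0_get_weather_cmpl-abc123",
                        "type": "function",
                        "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
                    }
                ],
            },
            "finish_reason": "tool_calls",
        }
    ],
}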
def chat_formatter_to_chat_completion_handler(
chat_formatter: ChatFormatter,
) -> LlamaChatCompletionHandler:
@@ -331,6 +516,8 @@ def chat_formatter_to_chat_completion_handler(
messages=messages,
functions=functions,
function_call=function_call,
tools=tools,
tool_choice=tool_choice,
)
prompt = result.prompt
if result.stop is not None:
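Since the handler now forwards tools and tool_choice into the formatter, both reach the template through the normal high-level API. A hedged usage sketch (the model path, chat_format value, and tool schema are placeholders, not from this commit):

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf", chat_format="chatml")  # placeholder path/format
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
print(response["choices"][0]["message"]["tool_calls"])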
@@ -341,6 +528,47 @@ def chat_formatter_to_chat_completion_handler(
if response_format is not None and response_format["type"] == "json_object":
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
# Convert legacy functions to tools
if functions is not None:
tools = [
{
"type": "function",
"function": function,
}
for function in functions
]
# Convert legacy function_call to tool_choice
if function_call is not None:
if isinstance(function_call, str) and (
function_call == "none" or function_call == "auto"
):
tool_choice = function_call
if isinstance(function_call, dict) and "name" in function_call:
tool_choice = {
"type": "function",
"function": {
"name": function_call["name"],
},
}
tool = None
if tool_choice is not None and isinstance(tool_choice, dict) and tools is not None:
name = tool_choice["function"]["name"]
tool = next((t for t in tools if t["function"]["name"] == name), None)
if tool is None:
raise ValueError(f"Tool choice '{name}' not found in tools.")
schema = tool["function"]["parameters"]
try:
# create grammar from json schema
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(schema), verbose=llama.verbose
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF, verbose=llama.verbose
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
temperature=temperature,
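This is the grammar setup referenced in the commit message: when a specific tool is chosen, generation is constrained to JSON matching that tool's parameter schema, falling back to a generic JSON grammar if schema conversion fails. A standalone sketch of the same two llama_grammar calls (the schema is illustrative):

import json

from llama_cpp import llama_grammar

schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}
try:
    # Constrain output to JSON objects that validate against the schema.
    grammar = llama_grammar.LlamaGrammar.from_json_schema(json.dumps(schema))
except Exception:
    # Fall back to accepting any syntactically valid JSON.
    grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)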
@@ -364,6 +592,11 @@ def chat_formatter_to_chat_completion_handler(
grammar=grammar,
logit_bias=logit_bias,
)
if tool is not None:
tool_name = tool["function"]["name"]
return _convert_completion_to_chat_function(
tool_name, completion_or_chunks, stream
)
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
return chat_completion_handler
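With stream=True, _convert_completion_to_chat_function first yields a role-only delta, then deltas carrying the argument text, and finally a chunk with finish_reason "tool_calls". A sketch of consuming that stream, reusing the hypothetical llm and tools from the earlier example:

for chunk in llm.create_chat_completion(
    messages=[{"role": "user", "content": "Weather in Paris?"}],
    tools=tools,  # as defined above
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
    stream=True,
):
    choice = chunk["choices"][0]
    delta = choice["delta"]
    if delta.get("function_call"):
        # Argument text arrives incrementally in the legacy field;
        # the tool_calls deltas carry the same text from the second content chunk on.
        print(delta["function_call"]["arguments"], end="")
    if choice["finish_reason"] == "tool_calls":
        print()  # stream complete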
@@ -2198,181 +2431,6 @@ def chatml_function_calling(
stream=stream,
)
(removed: a duplicate module-level definition of _convert_completion_to_chat_function that followed chatml_function_calling, identical to the version added above)
# Case 2: Tool choice by user
if isinstance(tool_choice, dict):
tool_name = tool_choice["function"]["name"]