From 8a60c7bc8cae7aa9770eeac0f482d39350763a6f Mon Sep 17 00:00:00 2001
From: Jeffrey Fong
Date: Mon, 18 Mar 2024 22:40:57 +0800
Subject: [PATCH] fix: Fix and optimize functionary chat handler (#1282)

* fix functionary chat logic

* further fixes

---------

Co-authored-by: Andrei
---
 llama_cpp/llama_chat_format.py | 131 ++++++++++++++++-----------------
 1 file changed, 65 insertions(+), 66 deletions(-)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 81ca552..c89cce8 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -1596,13 +1596,15 @@ def functionary_v1_v2_chat_handler(
         function_call = (
             tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
         )
+    else:
+        function_call = "auto"
 
     prompt = prepare_messages_for_inference(
         messages, tokenizer, version, functions, tools
     )
 
     # If no tools/functions are provided
-    if function_call is None and (functions is None or len(functions) == 0):
+    if function_call == "none" or functions is None or len(functions) == 0:
         if version == "v1":
             stop = END_ASSISTANT_TOKEN
         else:
@@ -1630,6 +1632,7 @@
             logits_processor=logits_processor,
             grammar=grammar,
         )
+        completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
         return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream)  # type: ignore
 
     assert stream is False  # TODO: support stream mode
@@ -1692,13 +1695,12 @@
 
         return completion
 
+    content = ""
     function_calls, function_bodies = [], []
 
     if version == "v1":
         # If no or "auto" tool_choice/function_call
-        if function_call is None or (
-            isinstance(function_call, str) and function_call == "auto"
-        ):
+        if isinstance(function_call, str) and function_call == "auto":
             stops = ["\n", END_ASSISTANT_TOKEN]
         # If tool_choice/function_call is "none"
         elif isinstance(function_call, str) and function_call == "none":
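The hunks above make `tool_choice` default to "auto" instead of None, route `tool_choice="none"` down the plain-completion path together with the no-tools case, and strip the leading whitespace functionary models tend to emit before free-form text. A minimal sketch of exercising the "none" path, with the model setup borrowed from the project README; the repo id, filename, and the `get_weather` tool are illustrative, not prescribed by this patch:

    from llama_cpp import Llama
    from llama_cpp.llama_tokenizer import LlamaHFTokenizer

    llm = Llama.from_pretrained(
        repo_id="meetkai/functionary-small-v2.2-GGUF",
        filename="functionary-small-v2.2.q4_0.gguf",
        chat_format="functionary-v2",
        tokenizer=LlamaHFTokenizer.from_pretrained("meetkai/functionary-small-v2.2-GGUF"),
    )

    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]

    # With this patch, "none" forces a plain-text reply even though tools are passed.
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "What is the weather in Hanoi?"}],
        tools=tools,
        tool_choice="none",
    )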
@@ -1747,70 +1749,67 @@
         else:
             function_bodies.append(completion_text.strip())
     else:
-        # Loop until all parallel function calls are generated
-        while True:
-            # If no or "auto" tool_choice/function_call
-            if function_call is None or (
-                isinstance(function_call, str) and function_call == "auto"
-            ):
-                grammar = None
-                stops = CONTENT_TOKEN
-            # If tool_choice/function_call is "none"
-            elif isinstance(function_call, str) and function_call == "none":
-                prompt = (
-                    prepare_messages_for_inference(messages, tokenizer, version, [], [])
-                    + "all\n<|content|>"
-                )
-                stops = STOP_TOKEN
-            # If tool_choice/function_call is provided
-            elif isinstance(function_call, dict):
-                prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
-                stops = STOP_TOKEN
-                function_call = function_call["name"]
-                function_calls.append(function_call)
-                grammar = get_grammar(function_call)
-            else:
-                prompt = prompt
-                stops = STOP_TOKEN
-
+        # If tool_choice/function_call is "none"
+        if isinstance(function_call, str) and function_call == "none":
+            prompt = (
+                prepare_messages_for_inference(messages, tokenizer, version, [], [])
+                + "all\n<|content|>"
+            )
+            stops = [STOP_TOKEN, FROM_TOKEN]
+            completion = create_completion(stop=stops)
+            completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
+            return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
+        # If tool_choice/function_call is provided
+        elif isinstance(function_call, dict):
+            prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
+            function_call = function_call["name"]
+            function_calls.append(function_call)
+            grammar = get_grammar(function_call)
+            stops = [STOP_TOKEN, FROM_TOKEN]
             completion = create_completion(stop=stops)
             completion_text = completion["choices"][0]["text"]
-
-            # If the generation does not involve a function call
-            if prompt.endswith("all\n<|content|>") and not completion_text.startswith(
-                "all"
-            ):
-                return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
-            # Generate model response if the model decides not to call any function
-            elif prompt.endswith(RECIPIENT_TOKEN) and completion_text.startswith("all"):
-                prompt += completion_text + CONTENT_TOKEN
-                completion = create_completion(stop=STOP_TOKEN)
-                return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
-            # Generate parameters if model decides to call a function
-            elif prompt.endswith(RECIPIENT_TOKEN):
-                function_calls.append(completion_text[:-1])
-                grammar = get_grammar(function_calls[-1])
-                completion = create_completion(stop=[STOP_TOKEN, "\n"])
-                function_bodies.append(completion["choices"][0]["text"].strip())
-                prompt += f"{function_calls[-1]}\n{CONTENT_TOKEN}{function_bodies[-1]}"
+            function_bodies.append(completion_text.strip())
+        # If "auto" or no tool_choice/function_call
+        elif isinstance(function_call, str) and function_call == "auto":
+            while True:
+                # Generate function name first
                 grammar = None
-
-                # Try to generate the beginning of next turn
-                # If empty completion, break from loop
-                next_turn_completion_text = create_completion(
-                    stop=[STOP_TOKEN, RECIPIENT_TOKEN]
-                )["choices"][0]["text"]
-                if len(next_turn_completion_text) > 0:
-                    prompt += f"\n{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}"
+                stops = CONTENT_TOKEN
+                completion = create_completion(stop=stops)
+                completion_text = completion["choices"][0]["text"]
+                function_name = completion_text.strip()
+                if function_name == "all":
+                    prompt += "all\n<|content|>"
                 else:
-                    break
-            # Break from loop if tool_choice/function_call is provided as a dict
-            else:
-                function_bodies.append(completion_text.strip())
-                break
+                    function_call = completion_text.strip()
+                    prompt += f"{function_call}\n<|content|>"
+                    function_calls.append(function_call)
+                    grammar = get_grammar(function_call)
+                # Generate content
+                stops = [RECIPIENT_TOKEN, STOP_TOKEN]
+                completion = create_completion(stop=stops)
+                completion_text = completion["choices"][0]["text"]
+                if function_name == "all":
+                    content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
+                    content = content.lstrip()
+                    # Check whether the model wants to generate another turn
+                    if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
+                        cleaned_completion_text = completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n").strip()
+                        prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
+                    else:
+                        break
+                else:
+                    function_bodies.append(completion_text.strip())
+                    # Check whether the model wants to generate another turn
+                    prompt += completion_text.strip()
+                    grammar = None
+                    completion = create_completion(stop=stops)
+                    if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
+                        prompt += "\n<|from|>assistant\n<|recipient|>"
+                    else:
+                        break
 
     assert "usage" in completion
-    assert len(function_calls) > 0
     assert len(function_calls) == len(function_bodies)
 
     tool_calls = []
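The rewritten v2 branch above replaces the single catch-all loop with three explicit cases, and the "auto" case now decodes each turn in two steps: first the recipient name (stopping at the content token), then the body (stopping at the recipient/stop tokens), iterating while the model keeps opening new `<|from|>assistant` turns. Roughly, the transcript shape it consumes looks like the sketch below; the exact spacing around the special tokens is an assumption, the tokens themselves come from the handler. Note that the suffix cleanup relies on `str.removesuffix`, which requires Python 3.9+.

    # Hypothetical two-turn v2 output: a plain-text turn, then a function call.
    transcript = (
        "<|from|>assistant\n"
        "<|recipient|>all\n"
        "<|content|>Let me check that for you.\n"
        "<|from|>assistant\n"
        "<|recipient|>get_weather\n"
        '<|content|>{"city": "Hanoi"}<|stop|>'
    )

    # How the trailing turn marker is trimmed from accumulated content:
    raw = "Let me check that for you.\n<|from|> assistant\n"
    cleaned = raw.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
    assert cleaned == "Let me check that for you."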
"assistant", - "content": None, + "content": None if content == "" else content, "function_call": { "name": tool_calls[0]["function"]["name"], "arguments": tool_calls[0]["function"]["arguments"], - }, - "tool_calls": tool_calls, + } if len(tool_calls) > 0 else None, + "tool_calls": tool_calls if len(tool_calls) > 0 else None, }, - "finish_reason": "tool_calls", + "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop", } ], usage=completion["usage"],