From 3af7b21ff1aec2ce4c2f8559e51de25907ed943d Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 3 Nov 2023 02:12:14 -0400 Subject: [PATCH] Add functionary support (#784) * Add common grammars and json-schema-to-grammar utility function from llama.cpp * Pass functions to format function * Add basic functionary formatting * Add LlamaChatHandler for more complex chat use cases * Add function calling example notebook * Add support for regular chat completions alongside function calling --- examples/notebooks/Functions.ipynb | 225 +++++++++++++++++ llama_cpp/llama.py | 94 +------ llama_cpp/llama_chat_format.py | 392 ++++++++++++++++++++++++++++- llama_cpp/llama_grammar.py | 315 ++++++++++++++++++++++- llama_cpp/server/app.py | 9 +- 5 files changed, 936 insertions(+), 99 deletions(-) create mode 100644 examples/notebooks/Functions.ipynb diff --git a/examples/notebooks/Functions.ipynb b/examples/notebooks/Functions.ipynb new file mode 100644 index 0000000..4d27bb0 --- /dev/null +++ b/examples/notebooks/Functions.ipynb @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"id\": \"chatcmpl-a6db1bbb-a128-4c28-88fe-30717ec806b2\",\n", + " \"object\": \"chat.completion\",\n", + " \"created\": 1698989577,\n", + " \"model\": \"gpt-3.5-turbo-0613\",\n", + " \"choices\": [\n", + " {\n", + " \"index\": 0,\n", + " \"message\": {\n", + " \"role\": \"assistant\",\n", + " \"content\": \"The current weather in Boston is sunny with a temperature of 72 degrees\"\n", + " },\n", + " \"finish_reason\": \"length\"\n", + " }\n", + " ],\n", + " \"usage\": {\n", + " \"prompt_tokens\": 135,\n", + " \"completion_tokens\": 16,\n", + " \"total_tokens\": 151\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "import openai\n", + "import json\n", + "\n", + "openai.api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n", + "openai.api_base = \"http://100.64.159.73:8000/v1\"\n", + "\n", + "# Example dummy function hard coded to return the same weather\n", + "# In production, this could be your backend API or an external API\n", + "def get_current_weather(location, unit=\"fahrenheit\"):\n", + " \"\"\"Get the current weather in a given location\"\"\"\n", + " weather_info = {\n", + " \"location\": location,\n", + " \"temperature\": \"72\",\n", + " \"unit\": unit,\n", + " \"forecast\": [\"sunny\", \"windy\"],\n", + " }\n", + " return json.dumps(weather_info)\n", + "\n", + "def run_conversation():\n", + " # Step 1: send the conversation and available functions to GPT\n", + " messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}]\n", + " functions = [\n", + " {\n", + " \"name\": \"get_current_weather\",\n", + " \"description\": \"Get the current weather in a given location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n", + " },\n", + " \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n", + " },\n", + " \"required\": [\"location\"],\n", + " },\n", + " }\n", + " ]\n", + " response = openai.ChatCompletion.create(\n", + " model=\"gpt-3.5-turbo-0613\",\n", + " messages=messages,\n", + " functions=functions,\n", + " function_call=\"auto\", # auto is default, but we'll be explicit\n", + " )\n", + " response_message = response[\"choices\"][0][\"message\"]\n", + "\n", + " # Step 2: check if GPT wanted to call a function\n", + " if response_message.get(\"function_call\"):\n", + " # Step 3: call the function\n", + " # Note: the JSON response may not always be valid; be sure to handle errors\n", + " available_functions = {\n", + " \"get_current_weather\": get_current_weather,\n", + " } # only one function in this example, but you can have multiple\n", + " function_name = response_message[\"function_call\"][\"name\"]\n", + " fuction_to_call = available_functions[function_name]\n", + " function_args = json.loads(response_message[\"function_call\"][\"arguments\"])\n", + " function_response = fuction_to_call(\n", + " location=function_args.get(\"location\"),\n", + " unit=function_args.get(\"unit\"),\n", + " )\n", + "\n", + " # Step 4: send the info on the function call and function response to GPT\n", + " messages.append(response_message) # extend conversation with assistant's reply\n", + " messages.append(\n", + " {\n", + " \"role\": \"function\",\n", + " \"name\": function_name,\n", + " \"content\": function_response,\n", + " }\n", + " ) # extend conversation with function response\n", + " second_response = openai.ChatCompletion.create(\n", + " model=\"gpt-3.5-turbo-0613\",\n", + " messages=messages,\n", + " ) # get a new response from GPT where it can see the function response\n", + " return second_response\n", + " else:\n", + " print(response)\n", + " print(\"No function\")\n", + "\n", + "print(run_conversation())" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name='Jason' age=25\n" + ] + } + ], + "source": [ + "from pydantic import BaseModel\n", + "from instructor import patch\n", + "\n", + "patch()\n", + "\n", + "class UserDetail(BaseModel):\n", + " name: str\n", + " age: int\n", + "\n", + "user: UserDetail = openai.ChatCompletion.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " response_model=UserDetail,\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n", + " ]\n", + ")\n", + "print(user)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"id\": \"chatcmpl-59bcefad-9df5-4d6b-802c-5537b3e9044e\",\n", + " \"object\": \"chat.completion\",\n", + " \"created\": 1698989585,\n", + " \"model\": \"gpt-3.5-turbo-0613\",\n", + " \"choices\": [\n", + " {\n", + " \"index\": 0,\n", + " \"message\": {\n", + " \"role\": \"assistant\",\n", + " \"content\": \"I don't have up-to-date information on the current weather conditions\"\n", + " },\n", + " \"finish_reason\": \"length\"\n", + " }\n", + " ],\n", + " \"usage\": {\n", + " \"prompt_tokens\": 62,\n", + " \"completion_tokens\": 16,\n", + " \"total_tokens\": 78\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "response = openai.ChatCompletion.create(\n", + " model=\"gpt-3.5-turbo-0613\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": 
\"What's the weather like in Boston?\"}\n", + " ]\n", + ")\n", + "print(response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python-3.8.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5+" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index e9b1999..3e8cf58 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -24,7 +24,7 @@ import ctypes from . import llama_cpp from .llama_types import * from .llama_grammar import LlamaGrammar -from . import llama_chat_format +import llama_cpp.llama_chat_format as llama_chat_format import numpy as np import numpy.typing as npt @@ -428,7 +428,7 @@ class Llama: if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) - + self.chat_format = chat_format self._n_vocab = self.n_vocab() @@ -1539,78 +1539,6 @@ class Llama: grammar=grammar, ) - def _convert_text_completion_to_chat( - self, completion: Completion - ) -> ChatCompletion: - return { - "id": "chat" + completion["id"], - "object": "chat.completion", - "created": completion["created"], - "model": completion["model"], - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": completion["choices"][0]["text"], - }, - "finish_reason": completion["choices"][0]["finish_reason"], - } - ], - "usage": completion["usage"], - } - - def _convert_text_completion_chunks_to_chat( - self, - chunks: Iterator[CompletionChunk], - ) -> Iterator[ChatCompletionChunk]: - for i, chunk in enumerate(chunks): - if i == 0: - yield { - "id": "chat" + chunk["id"], - "model": chunk["model"], - "created": chunk["created"], - "object": "chat.completion.chunk", - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - }, - "finish_reason": None, - } - ], - } - yield { - "id": "chat" + chunk["id"], - "model": chunk["model"], - "created": chunk["created"], - "object": "chat.completion.chunk", - "choices": [ - { - "index": 0, - "delta": { - "content": chunk["choices"][0]["text"], - } - if chunk["choices"][0]["finish_reason"] is None - else {}, - "finish_reason": chunk["choices"][0]["finish_reason"], - } - ], - } - - def _convert_completion_to_chat( - self, - completion_or_chunks: Union[Completion, Iterator[CompletionChunk]], - stream: bool = False, - ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: - if stream: - chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore - return self._convert_text_completion_chunks_to_chat(chunks) - else: - completion: Completion = completion_or_chunks # type: ignore - return self._convert_text_completion_to_chat(completion) - def create_chat_completion( self, messages: List[ChatCompletionRequestMessage], @@ -1648,19 +1576,12 @@ class Llama: Returns: Generated chat completion or a stream of chat completion chunks. 
""" - - format = llama_chat_format.get_chat_format(self.chat_format) - result = format( + handler = llama_chat_format.get_chat_completion_handler(self.chat_format) + return handler( + self, messages=messages, - ) - prompt = result.prompt - if result.stop is not None: - stop = [] if stop is None else [stop] if isinstance(stop, str) else stop - rstop = result.stop if isinstance(result.stop, list) else [result.stop] - stop = stop + rstop - - completion_or_chunks = self.create_completion( - prompt=prompt, + functions=functions, + function_call=function_call, temperature=temperature, top_p=top_p, top_k=top_k, @@ -1678,7 +1599,6 @@ class Llama: logits_processor=logits_processor, grammar=grammar, ) - return self._convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore def _free_model(self, *, _lbatch_free=llama_cpp._lib.llama_batch_free, _lfree_model=llama_cpp._lib.llama_free_model, _free=llama_cpp._lib.llama_free): batch = getattr(self, 'batch', None) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 518acc5..f92793d 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,6 +1,53 @@ +from __future__ import annotations + import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Union, Protocol +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol from . import llama_types +from . import llama + + +class LlamaChatCompletionHandler(Protocol): + def __call__( + self, + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[ + Union[str, llama_types.ChatCompletionFunctionCall] + ] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, + ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: + ... + + +CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {} + + +def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler: + return CHAT_HANDLERS[name] + + +def register_chat_completion_handler(name: str): + def decorator(f: LlamaChatCompletionHandler): + CHAT_HANDLERS[name] = f + return f + + return decorator def _get_system_message( @@ -119,12 +166,150 @@ class ChatFormatter(Protocol): ... 
+class BasicChatHandler: + def __init__(self, chat_format: str): + self.chat_format = chat_format + + +def _convert_text_completion_to_chat( + completion: llama_types.Completion, +) -> llama_types.ChatCompletion: + return { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": completion["choices"][0]["text"], + }, + "finish_reason": completion["choices"][0]["finish_reason"], + } + ], + "usage": completion["usage"], + } + + +def _convert_text_completion_chunks_to_chat( + chunks: Iterator[llama_types.CompletionChunk], +) -> Iterator[llama_types.ChatCompletionChunk]: + for i, chunk in enumerate(chunks): + if i == 0: + yield { + "id": "chat" + chunk["id"], + "model": chunk["model"], + "created": chunk["created"], + "object": "chat.completion.chunk", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + }, + "finish_reason": None, + } + ], + } + yield { + "id": "chat" + chunk["id"], + "model": chunk["model"], + "created": chunk["created"], + "object": "chat.completion.chunk", + "choices": [ + { + "index": 0, + "delta": { + "content": chunk["choices"][0]["text"], + } + if chunk["choices"][0]["finish_reason"] is None + else {}, + "finish_reason": chunk["choices"][0]["finish_reason"], + } + ], + } + + +def _convert_completion_to_chat( + completion_or_chunks: Union[ + llama_types.Completion, Iterator[llama_types.CompletionChunk] + ], + stream: bool = False, +) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: + if stream: + chunks: Iterator[llama_types.CompletionChunk] = completion_or_chunks # type: ignore + return _convert_text_completion_chunks_to_chat(chunks) + else: + completion: llama_types.Completion = completion_or_chunks # type: ignore + return _convert_text_completion_to_chat(completion) + + _CHAT_FORMATS: Dict[str, ChatFormatter] = {} def register_chat_format(name: str): def decorator(f: ChatFormatter): - _CHAT_FORMATS[name] = f + def basic_create_chat_completion( + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[ + Union[str, llama_types.ChatCompletionFunctionCall] + ] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, + ) -> Union[ + llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk] + ]: + result = f( + messages=messages, + functions=functions, + function_call=function_call, + ) + prompt = result.prompt + if result.stop is not None: + stop = [] if stop is None else [stop] if isinstance(stop, str) else stop + rstop = result.stop if isinstance(result.stop, list) else [result.stop] + stop = stop + rstop + + completion_or_chunks = llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + 
repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + ) + return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore + + register_chat_completion_handler(name)(basic_create_chat_completion) return f return decorator @@ -320,3 +505,206 @@ def format_chatml( _messages.append((_roles["assistant"], None)) _prompt = _format_chatml(system_message, _messages, _sep) return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_completion_handler("functionary") +def functionary_chat_handler( + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[Union[str, llama_types.ChatCompletionFunctionCall]] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, +) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: + SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" + + def generate_schema_from_functions( + functions: List[llama_types.ChatCompletionFunctions], + namespace: str = "functions", + ): + """ + Convert functions schema to a schema that language models can understand. + """ + + schema = ( + "// Supported function definitions that should be called when necessary.\n" + ) + schema += f"namespace {namespace} {{\n\n" + + for function in functions: + # Convert a Function object to dict, if necessary + function_name = function["name"] + description = function.get("description", "") + schema += f"// {description}\n" + schema += f"type {function_name}" + + parameters = function.get("parameters", None) + schema += " = (_: {\n" + required_params = parameters.get("required", []) + for param_name, param in parameters.get("properties", {}).items(): + # Param Description + description = param.get("description") + if description is not None: + schema += f"// {description}\n" + + # Param Name + schema += f"{param_name}" + if param_name not in required_params: + schema += "?" 
+ + # Param Type + param_type = param.get("type", "any") + if param_type == "integer": + param_type = "number" + if "enum" in param: + param_type = " | ".join([f'"{v}"' for v in param["enum"]]) + schema += f": {param_type},\n" + + schema += "}) => any;\n\n" + + schema += f"}} // namespace {namespace}" + + return schema + + def prepare_messages_for_inference( + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, + ): + all_messages: List[llama_types.ChatCompletionRequestMessage] = [] + if functions is not None: + all_messages.append( + llama_types.ChatCompletionRequestMessage( + role="system", content=generate_schema_from_functions(functions) + ) + ) + + all_messages.append( + llama_types.ChatCompletionRequestMessage( + role="system", content=SYSTEM_MESSAGE + ) + ) + + for message in messages: + # Function call responses + if message["role"] == "function" and "name" in message: + message["name"] = f"functions.{message['name']}" + # Function call requests by assistant + if "function_call" in message: + message["function_call"][ + "name" + ] = f"functions.{message['function_call']['name']}" + all_messages.append(message) + + all_messages.append( + llama_types.ChatCompletionRequestMessage(role="assistant", content=None) + ) + + def message_to_str(msg: llama_types.ChatCompletionRequestMessage): + if msg["role"] == "system": + return f"system:\n{msg['content']}\n" + + elif msg["role"] == "function" and "name" in msg: + return f"function name={msg['name']}:\n{msg['content']}\n" + elif msg["role"] == "function" and "function_call" in msg: + return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n" + elif msg["role"] == "user": + if msg["content"] is None: + return "user:\n" + else: + return f"user:\n{msg['content']}\n" + elif msg["role"] == "assistant": + if msg["content"] is not None and "function_call" in msg: + return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + elif "function_call" in msg: + return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + elif msg["content"] is None: + return "assistant" + else: + return f"assistant:\n{msg['content']}\n" + else: + raise ValueError(f"Unsupported role: {msg['role']}") + + return "".join([message_to_str(msg) for msg in all_messages]) + + prompt = prepare_messages_for_inference(messages, functions) + + if function_call is None and (functions is None or len(functions) == 0): + completion_or_completion_chunks = llama.create_completion( + prompt=prompt + ":\n", + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream=stream, + stop=["user:", ""], + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + ) + return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore + + if function_call is None or ( + isinstance(function_call, str) and function_call == "auto" + ): + stop = "\n" + completion: llama_types.Completion = llama.create_completion( + prompt=prompt, stop=stop, stream=False + ) # type: ignore + completion_text = completion["choices"][0]["text"] + # strip " to=functions." 
and ending ":" + function_call = completion_text[14:-1] + new_prompt = prompt + completion_text + stop + elif isinstance(function_call, str) and function_call != "none": + new_prompt = prompt + f"assistant:\n" + elif isinstance(function_call, dict): + new_prompt = prompt + f"assistant to={function_call['name']}:\n" + function_call = function_call["name"] + else: + new_prompt = prompt + f"assistant:\n" + + completion: llama_types.Completion = llama.create_completion( + prompt=new_prompt, stop=["user:", ""], stream=False + ) # type: ignore + + return llama_types.CreateChatCompletionResponse( + id="chat" + completion["id"], + object="chat.completion", + created=completion["created"], + model=completion["model"], + choices=[ + { + "index": 0, + "message": { + "role": "function", + "content": None, + "function_call": { + "name": function_call, + "arguments": completion["choices"][0]["text"], + }, + }, + "finish_reason": "function_call", + } + ], + usage=completion["usage"], + ) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 8ff1565..29431d9 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,4 +1,5 @@ -"""C++ implementation of the llama grammar parser.""" +"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" + # flake8: noqa from pathlib import Path import sys @@ -1056,8 +1057,7 @@ def print_rule( # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " - + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1102,9 +1102,7 @@ def print_rule( for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype if case is llama_gretype.LLAMA_GRETYPE_END: - raise RuntimeError( - "unexpected end of rule: " + str(rule_id) + "," + str(i) - ) + raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: @@ -1186,3 +1184,308 @@ def print_grammar(file: TextIO, state: parse_state) -> None: f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, ) + + +"""llama.cpp gbnf rules from vendor/llama.cpp/grammars""" + +ARITHMETIC_GBNF = """\ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* +""" + +C_GBNF = """\ +root ::= (declaration)* + +declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" + +dataType ::= "int" ws | "float" ws | "char" ws +identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* + +parameter ::= dataType identifier + +statement ::= + ( dataType identifier ws "=" ws expression ";" ) | + ( identifier ws "=" ws expression ";" ) | + ( identifier ws "(" argList? ")" ";" ) | + ( "return" ws expression ";" ) | + ( "while" "(" condition ")" "{" statement* "}" ) | + ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | + ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? 
) | + ( singleLineComment ) | + ( multiLineComment ) + +forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression +forUpdate ::= identifier ws "=" ws expression + +condition ::= expression relationOperator expression +relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") + +expression ::= term (("+" | "-") term)* +term ::= factor(("*" | "/") factor)* + +factor ::= identifier | number | unaryTerm | funcCall | parenExpression +unaryTerm ::= "-" factor +funcCall ::= identifier "(" argList? ")" +parenExpression ::= "(" ws expression ws ")" + +argList ::= expression ("," ws expression)* + +number ::= [0-9]+ + +singleLineComment ::= "//" [^\n]* "\n" +multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" + +ws ::= ([ \t\n]+) +""" + +CHESS_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + +JAPANESE_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + +JSON_ARR_GBNF = """\ +# This is the same as json.gbnf but we restrict whitespaces at the end of the root array +# Useful for generating JSON arrays + +root ::= arr +value ::= object | array | string | number | ("true" | "false" | "null") ws + +arr ::= + "[\n" ws ( + value + (",\n" ws value)* + )? "]" + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + + +JSON_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? 
ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)?""" + +LIST_GBNF = """\ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" +""" + +"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py""" +import json +import re +from typing import List, Optional + +# whitespace is constrained to a single space char to prevent model "running away" in +# whitespace. Also maybe improves generation quality? +SPACE_RULE = '" "?' + +PRIMITIVE_RULES = { + "boolean": '("true" | "false") space', + "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', + "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space', + "string": r""" "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* "\"" space """, + "null": '"null" space', +} + +INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} + + +class SchemaConverter: + def __init__(self, prop_order): + self._prop_order = prop_order + self._rules = {"space": SPACE_RULE} + + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) + ) + return f'"{escaped}"' + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub("-", name) + if esc_name not in self._rules or self._rules[esc_name] == rule: + key = esc_name + else: + i = 0 + while f"{esc_name}{i}" in self._rules: + i += 1 + key = f"{esc_name}{i}" + self._rules[key] = rule + return key + + def visit(self, schema, name): + schema_type = schema.get("type") + rule_name = name or "root" + + if "oneOf" in schema or "anyOf" in schema: + rule = " | ".join( + ( + self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') + for i, alt_schema in enumerate( + schema.get("oneOf") or schema["anyOf"] + ) + ) + ) + return self._add_rule(rule_name, rule) + + elif "const" in schema: + return self._add_rule(rule_name, self._format_literal(schema["const"])) + + elif "enum" in schema: + rule = " | ".join((self._format_literal(v) for v in schema["enum"])) + return self._add_rule(rule_name, rule) + + elif schema_type == "object" and "properties" in schema: + # TODO: `required` keyword + prop_order = self._prop_order + prop_pairs = sorted( + schema["properties"].items(), + # sort by position in prop_order (if specified) then by key + key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]), + ) + + rule = '"{" space' + for i, (prop_name, prop_schema) in enumerate(prop_pairs): + prop_rule_name = self.visit( + prop_schema, f'{name}{"-" if name else ""}{prop_name}' + ) + if i > 0: + rule += ' "," space' + rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' + rule += ' "}" space' + + return self._add_rule(rule_name, rule) + + elif schema_type == "array" and "items" in schema: + # TODO `prefixItems` keyword + item_rule_name = self.visit( + schema["items"], f'{name}{"-" if name else ""}item' + ) + rule = ( + f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? 
"]" space' + ) + return self._add_rule(rule_name, rule) + + else: + assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" + return self._add_rule( + "root" if rule_name == "root" else schema_type, + PRIMITIVE_RULES[schema_type], + ) + + def format_grammar(self): + return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) + + +def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): + prop_order = prop_order or [] + schema = json.load(schema) + prop_order = {name: idx for idx, name in enumerate(prop_order)} + converter = SchemaConverter(prop_order) + converter.visit(schema, "") + return converter.format_grammar() diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 73b660a..bec9561 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -743,21 +743,22 @@ async def create_embedding( class ChatCompletionRequestMessage(BaseModel): - role: Literal["system", "user", "assistant"] = Field( + role: Literal["system", "user", "assistant", "function"] = Field( default="user", description="The role of the message." ) - content: str = Field(default="", description="The content of the message.") + content: Optional[str] = Field(default="", description="The content of the message.") +from typing import Any class CreateChatCompletionRequest(BaseModel): - messages: List[ChatCompletionRequestMessage] = Field( + messages: List[Any] = Field( default=[], description="A list of messages to generate completions for." ) functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field( default=None, description="A list of functions to apply to the generated completions.", ) - function_call: Optional[Union[str, llama_cpp.ChatCompletionFunctionCall]] = Field( + function_call: Optional[Union[Literal["auto", "none"], llama_cpp.ChatCompletionFunctionCallOption]] = Field( default=None, description="A function to apply to the generated completions.", )