diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index bd8110f..9a09a28 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -91,6 +91,19 @@ def _format_add_colon_space_single( return ret +def _format_chatml( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the chatml style.""" + ret = "" if system_message == "" else system_message + sep + "\n" + for role, message in messages: + if message: + ret += role + "\n" + message + sep + "\n" + else: + ret += role + "\n" + return ret + + @dataclasses.dataclass class ChatFormatterResponse: prompt: str @@ -290,3 +303,20 @@ def format_open_orca( _messages.append((roles[1], None)) _prompt = _format_add_colon_space_single(system_message, _messages, sep) return ChatFormatterResponse(prompt=_prompt, stop=stop_str) + + +@register_chat_format("chatml") +def format_chatml( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + system_template = """<|im_start|>system +{system_message}""" + system_message = _get_system_message(messages) + system_message = system_template.format(system_message=system_message) + _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant") + _sep = "<|im_end|>" + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_chatml(system_message, _messages, _sep) + return ChatFormatterResponse(prompt=_prompt)