feat: Add Llama-3 chat format (#1371)

* feat: Add Llama-3 chat format

* feat: Auto-detect Llama-3 chat format from gguf template

* feat: Update llama.cpp to b2715

Includes proper Llama-3 <|eot_id|> token handling.

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
This commit is contained in:
abk16 2024-04-23 06:33:29 +00:00 committed by GitHub
parent 617d536e1c
commit 8559e8ce88
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -35,6 +35,9 @@ MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
LLAMA3_INSTRUCT_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
### Chat Completion Handler ###
@ -729,6 +732,9 @@ def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[s
metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE):
return "mistral-instruct"
if metadata["tokenizer.chat_template"] == LLAMA3_INSTRUCT_CHAT_TEMPLATE:
return "llama-3"
return None
@ -920,6 +926,26 @@ def format_llama2(
return ChatFormatterResponse(prompt=_prompt)
# Chat format for Llama-3 models, see more details at:
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
@register_chat_format("llama-3")
def format_llama3(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
_roles = dict(
system="<|start_header_id|>system<|end_header_id|>\n\n",
user="<|start_header_id|>user<|end_header_id|>\n\n",
assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
)
_begin_token = "<|begin_of_text|>"
_sep = "<|eot_id|>"
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
_prompt = _format_no_colon_single(_begin_token, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
@register_chat_format("alpaca")
def format_alpaca(
messages: List[llama_types.ChatCompletionRequestMessage],