Compare commits

...

31 commits

Author SHA1 Message Date
baalajimaestro ce85be97e2
Merge https://github.com/abetlen/llama-cpp-python 2024-04-25 10:48:33 +05:30
Andrei Betlen c50d3300d2 chore: Bump version 2024-04-23 02:53:20 -04:00
Andrei Betlen 611781f531 ci: Build arm64 wheels. Closes #1342 2024-04-23 02:48:09 -04:00
Sean Bailey 53ebcc8bb5
feat(server): Provide ability to dynamically allocate all threads if desired using -1 (#1364) 2024-04-23 02:35:38 -04:00
Geza Velkey 507c1da066
fix: Update scikit-build-core build dependency avoid bug in 0.9.1 (#1370)
cmake [options] <path-to-source>
  cmake [options] <path-to-existing-build>
  cmake [options] -S <path-to-source> -B <path-to-build>

Specify a source directory to (re-)generate a build system for it in the
current working directory.  Specify an existing build directory to
re-generate its build system.

Run 'cmake --help' for more information. issue
2024-04-23 02:34:15 -04:00
abk16 8559e8ce88
feat: Add Llama-3 chat format (#1371)
* feat: Add Llama-3 chat format

* feat: Auto-detect Llama-3 chat format from gguf template

* feat: Update llama.cpp to b2715

Includes proper Llama-3 <|eot_id|> token handling.

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
2024-04-23 02:33:29 -04:00
Andrei Betlen 617d536e1c feat: Update llama.cpp 2024-04-23 02:31:40 -04:00
Andrei Betlen d40a250ef3 feat: Use new llama_token_is_eog in create_completions 2024-04-22 00:35:47 -04:00
Andrei Betlen b21ba0e2ac Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main 2024-04-21 20:46:42 -04:00
Andrei Betlen 159cc4e5d9 feat: Update llama.cpp 2024-04-21 20:46:40 -04:00
Andrei Betlen 0281214863 chore: Bump version 2024-04-20 00:09:37 -04:00
Andrei Betlen cc81afebf0 feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct 2024-04-20 00:00:53 -04:00
Andrei Betlen d17c1887a3 feat: Update llama.cpp 2024-04-19 23:58:16 -04:00
Andrei Betlen 893a27a736 chore: Bump version 2024-04-18 01:43:39 -04:00
Andrei Betlen a128c80500 feat: Update llama.cpp 2024-04-18 01:39:45 -04:00
Lucca Zenóbio 4f42664955
feat: update grammar schema converter to match llama.cpp (#1353)
* feat: improve function calling

* feat:grammar

* fix

* fix

* fix
2024-04-18 01:36:25 -04:00
Andrei Betlen fa4bb0cf81 Revert "feat: Update json to grammar (#1350)"
This reverts commit 610a592f70.
2024-04-17 16:18:16 -04:00
Lucca Zenóbio 610a592f70
feat: Update json to grammar (#1350)
* feat: improve function calling

* feat:grammar
2024-04-17 10:10:21 -04:00
khimaros b73c73c0c6
feat: add disable_ping_events flag (#1257)
for backward compatibility, this is false by default

it can be set to true to disable EventSource pings
which are not supported by some OpenAI clients.

fixes https://github.com/abetlen/llama-cpp-python/issues/1256
2024-04-17 10:08:19 -04:00
tc-wolf 4924455dec
feat: Make saved state more compact on-disk (#1296)
* State load/save changes

- Only store up to `n_tokens` logits instead of full `(n_ctx, n_vocab)`
  sized array.
  - Difference between ~350MB and ~1500MB for example prompt with ~300
    tokens (makes sense lol)
- Auto-formatting changes

* Back out formatting changes
2024-04-17 10:06:50 -04:00
Andrei Betlen 9842cbf99d feat: Update llama.cpp 2024-04-17 10:06:15 -04:00
ddh0 c96b2daebf feat: Use all available CPUs for batch processing (#1345) 2024-04-17 10:05:54 -04:00
Andrei Betlen a420f9608b feat: Update llama.cpp 2024-04-14 19:14:09 -04:00
Andrei Betlen 90dceaba8a feat: Update llama.cpp 2024-04-14 11:35:57 -04:00
Andrei Betlen 2e9ffd28fd feat: Update llama.cpp 2024-04-12 21:09:12 -04:00
Andrei Betlen ef29235d45 chore: Bump version 2024-04-10 03:44:46 -04:00
Andrei Betlen bb65b4d764 fix: pass correct type to chat handlers for chat completion logprobs 2024-04-10 03:41:55 -04:00
Andrei Betlen 060bfa64d5 feat: Add support for yaml based configs 2024-04-10 02:47:01 -04:00
Andrei Betlen 1347e1d050 feat: Add typechecking for ctypes structure attributes 2024-04-10 02:40:41 -04:00
Andrei Betlen 889d0e8981 feat: Update llama.cpp 2024-04-10 02:25:58 -04:00
Andrei Betlen 56071c956a feat: Update llama.cpp 2024-04-09 09:53:49 -04:00
14 changed files with 1201 additions and 191 deletions

View file

@ -41,6 +41,35 @@ jobs:
with: with:
path: ./wheelhouse/*.whl path: ./wheelhouse/*.whl
build_arm64_wheels:
name: Build arm64 wheels
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
submodules: "recursive"
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
with:
platforms: linux/arm64
- name: Build wheels
uses: pypa/cibuildwheel@v2.16.5
env:
CIBW_SKIP: "*musllinux* pp*"
CIBW_REPAIR_WHEEL_COMMAND: ""
CIBW_ARCHS: "aarch64"
CIBW_BUILD: "cp38-* cp39-* cp310-* cp311-* cp312-*"
with:
output-dir: wheelhouse/
- name: Upload wheels as artifacts
uses: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.version }}
path: wheelhouse/*.whl
build_sdist: build_sdist:
name: Build source distribution name: Build source distribution
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -65,7 +94,7 @@ jobs:
release: release:
name: Release name: Release
needs: [build_wheels, build_sdist] needs: [build_wheels, build_arm64_wheels, build_sdist]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:

View file

@ -7,6 +7,35 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [0.2.64]
- feat: Update llama.cpp to ggerganov/llama.cpp@4e96a812b3ce7322a29a3008db2ed73d9087b176
- feat: Add `llama-3` chat format by @andreabak in #1371
- feat: Use new llama_token_is_eog in create_completions by @abetlen in d40a250ef3cfaa8224d12c83776a2f1de96ae3d1
- feat(server): Provide ability to dynamically allocate all threads if desired using -1 by @sean-bailey in #1364
- ci: Build arm64 wheels by @gaby in 611781f5319719a3d05fefccbbf0cc321742a026
- fix: Update scikit-build-core build dependency avoid bug in 0.9.1 by @evelkey in #1370
## [0.2.63]
- feat: Update llama.cpp to ggerganov/llama.cpp@0e4802b2ecbaab04b4f829fde4a3096ca19c84b5
- feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct by @abetlen in cc81afebf04d26ca1ac3cf72f23f18da6ab58588
## [0.2.62]
- feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c
- feat: update grammar schema converter to match llama.cpp by @themrzmaster in #1353
- feat: add disable_ping_events flag by @khimaros in #1257
- feat: Make saved state more compact on-disk by @tc-wolf in #1296
- feat: Use all available CPUs for batch processing by @ddh0 in #1345
## [0.2.61]
- feat: Update llama.cpp to ggerganov/llama.cpp@ba5e134e073ec6837078c874aba44a702944a676
- fix: pass correct type to chat handlers for chat completion logprobs by @abetlen in bb65b4d76411112c6fb0bf759efd746f99ef3c6b
- feat: Add support for yaml based server configs by @abetlen in 060bfa64d529ade2af9b1f4e207a3937bbc4138f
- feat: Add typechecking for ctypes structure attributes by @abetlen in 1347e1d050fc5a9a32ffe0bb3e22858da28003bd
## [0.2.60] ## [0.2.60]
- feat: Update llama.cpp to ggerganov/llama.cpp@75cd4c77292034ecec587ecb401366f57338f7c0 - feat: Update llama.cpp to ggerganov/llama.cpp@75cd4c77292034ecec587ecb401366f57338f7c0

View file

@ -0,0 +1,30 @@
"""llama-cpp-python server from scratch in a single file.
"""
# import llama_cpp
# path = b"../../models/Qwen1.5-0.5B-Chat-GGUF/qwen1_5-0_5b-chat-q8_0.gguf"
# model_params = llama_cpp.llama_model_default_params()
# model = llama_cpp.llama_load_model_from_file(path, model_params)
# if model is None:
# raise RuntimeError(f"Failed to load model from file: {path}")
# ctx_params = llama_cpp.llama_context_default_params()
# ctx = llama_cpp.llama_new_context_with_model(model, ctx_params)
# if ctx is None:
# raise RuntimeError("Failed to create context")
from fastapi import FastAPI
app = FastAPI()
import openai.types.chat as types
@app.post("/v1/chat/completions")
def create_chat_completions():
return {"message": "Hello World"}

View file

@ -1,4 +1,4 @@
from .llama_cpp import * from .llama_cpp import *
from .llama import * from .llama import *
__version__ = "0.2.60" __version__ = "0.2.64"

View file

@ -181,20 +181,20 @@ class _LlamaModel:
) )
return list(tokens[:n_tokens]) return list(tokens[:n_tokens])
def token_to_piece(self, token: int) -> bytes: def token_to_piece(self, token: int, special: bool = False) -> bytes:
assert self.model is not None assert self.model is not None
buf = ctypes.create_string_buffer(32) buf = ctypes.create_string_buffer(32)
llama_cpp.llama_token_to_piece(self.model, token, buf, 32) llama_cpp.llama_token_to_piece(self.model, token, buf, 32, special)
return bytes(buf) return bytes(buf)
def detokenize(self, tokens: List[int]) -> bytes: def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
assert self.model is not None assert self.model is not None
output = b"" output = b""
size = 32 size = 32
buffer = (ctypes.c_char * size)() buffer = (ctypes.c_char * size)()
for token in tokens: for token in tokens:
n = llama_cpp.llama_token_to_piece( n = llama_cpp.llama_token_to_piece(
self.model, llama_cpp.llama_token(token), buffer, size self.model, llama_cpp.llama_token(token), buffer, size, special
) )
assert n <= size assert n <= size
output += bytes(buffer[:n]) output += bytes(buffer[:n])
@ -597,13 +597,13 @@ def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> li
return list(result) return list(result)
def _token_to_piece(model: _LlamaModel, token: int) -> str: def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str:
assert model.model is not None assert model.model is not None
result = (ctypes.c_char * 8)(0) result = (ctypes.c_char * 8)(0)
n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
if n_tokens < 0: if n_tokens < 0:
result = (ctypes.c_char * -n_tokens)(0) result = (ctypes.c_char * -n_tokens)(0)
check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special)
if check != -n_tokens: if check != -n_tokens:
raise RuntimeError(f"Failed to get piece: token={token}") raise RuntimeError(f"Failed to get piece: token={token}")
else: else:

View file

@ -18,6 +18,7 @@ from typing import (
Iterator, Iterator,
Deque, Deque,
Callable, Callable,
Dict,
) )
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
@ -262,9 +263,7 @@ class Llama:
self.n_batch = min(n_ctx, n_batch) # ??? self.n_batch = min(n_ctx, n_batch) # ???
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
self.n_threads_batch = n_threads_batch or max( self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
multiprocessing.cpu_count() // 2, 1
)
# Context Params # Context Params
self.context_params = llama_cpp.llama_context_default_params() self.context_params = llama_cpp.llama_context_default_params()
@ -427,7 +426,10 @@ class Llama:
print(f"Using chat bos_token: {bos_token}", file=sys.stderr) print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
self.chat_handler = llama_chat_format.Jinja2ChatFormatter( self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
template=template, eos_token=eos_token, bos_token=bos_token template=template,
eos_token=eos_token,
bos_token=bos_token,
stop_token_ids=[eos_token_id],
).to_chat_handler() ).to_chat_handler()
if self.chat_format is None and self.chat_handler is None: if self.chat_format is None and self.chat_handler is None:
@ -1032,7 +1034,8 @@ class Llama:
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
): ):
if token == self._token_eos: assert self._model.model is not None
if llama_cpp.llama_token_is_eog(self._model.model, token):
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens) text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "stop" finish_reason = "stop"
break break
@ -1664,7 +1667,8 @@ class Llama:
top_k=top_k, top_k=top_k,
min_p=min_p, min_p=min_p,
typical_p=typical_p, typical_p=typical_p,
logprobs=top_logprobs if logprobs else None, logprobs=logprobs,
top_logprobs=top_logprobs,
stream=stream, stream=stream,
stop=stop, stop=stop,
seed=seed, seed=seed,
@ -1792,7 +1796,7 @@ class Llama:
file=sys.stderr, file=sys.stderr,
) )
return LlamaState( return LlamaState(
scores=self.scores.copy(), scores=self._scores.copy(),
input_ids=self.input_ids.copy(), input_ids=self.input_ids.copy(),
n_tokens=self.n_tokens, n_tokens=self.n_tokens,
llama_state=bytes(llama_state_compact), llama_state=bytes(llama_state_compact),
@ -1801,7 +1805,9 @@ class Llama:
def load_state(self, state: LlamaState) -> None: def load_state(self, state: LlamaState) -> None:
assert self._ctx.ctx is not None assert self._ctx.ctx is not None
self.scores = state.scores.copy() # Only filling in up to `n_tokens` and then zero-ing out the rest
self.scores[: state.n_tokens, :] = state.scores.copy()
self.scores[state.n_tokens :, :] = 0.0
self.input_ids = state.input_ids.copy() self.input_ids = state.input_ids.copy()
self.n_tokens = state.n_tokens self.n_tokens = state.n_tokens
state_size = state.llama_state_size state_size = state.llama_state_size
@ -1952,7 +1958,6 @@ class Llama:
local_dir_use_symlinks=local_dir_use_symlinks, local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cache_dir, cache_dir=cache_dir,
local_files_only=True, local_files_only=True,
) )
else: else:
model_path = os.path.join(local_dir, filename) model_path = os.path.join(local_dir, filename)

View file

@ -10,6 +10,9 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, P
import jinja2 import jinja2
import numpy as np
import numpy.typing as npt
import llama_cpp.llama as llama import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar import llama_cpp.llama_grammar as llama_grammar
@ -32,6 +35,9 @@ MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json # Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
LLAMA3_INSTRUCT_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
### Chat Completion Handler ### ### Chat Completion Handler ###
@ -77,6 +83,8 @@ class LlamaChatCompletionHandler(Protocol):
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
logits_processor: Optional[llama.LogitsProcessorList] = None, logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None, grammar: Optional[llama.LlamaGrammar] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs, # type: ignore **kwargs, # type: ignore
) -> Union[ ) -> Union[
llama_types.CreateChatCompletionResponse, llama_types.CreateChatCompletionResponse,
@ -148,6 +156,7 @@ class ChatFormatterResponse:
prompt: str prompt: str
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = None
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
class ChatFormatter(Protocol): class ChatFormatter(Protocol):
@ -171,12 +180,14 @@ class Jinja2ChatFormatter(ChatFormatter):
eos_token: str, eos_token: str,
bos_token: str, bos_token: str,
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
stop_token_ids: Optional[List[int]] = None,
): ):
"""A chat formatter that uses jinja2 templates to format the prompt.""" """A chat formatter that uses jinja2 templates to format the prompt."""
self.template = template self.template = template
self.eos_token = eos_token self.eos_token = eos_token
self.bos_token = bos_token self.bos_token = bos_token
self.add_generation_prompt = add_generation_prompt self.add_generation_prompt = add_generation_prompt
self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
self._environment = jinja2.Environment( self._environment = jinja2.Environment(
loader=jinja2.BaseLoader(), loader=jinja2.BaseLoader(),
@ -209,7 +220,16 @@ class Jinja2ChatFormatter(ChatFormatter):
tool_choice=tool_choice, tool_choice=tool_choice,
) )
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token]) stopping_criteria = None
if self.stop_token_ids is not None:
def stop_on_last_token(
tokens: npt.NDArray[np.intc],
logits: npt.NDArray[np.single]
) -> bool:
return tokens[-1] in self.stop_token_ids
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
def to_chat_handler(self) -> LlamaChatCompletionHandler: def to_chat_handler(self) -> LlamaChatCompletionHandler:
return chat_formatter_to_chat_completion_handler(self) return chat_formatter_to_chat_completion_handler(self)
@ -338,7 +358,7 @@ def _convert_completion_to_chat_function(
} }
], ],
}, },
"logprobs": None, "logprobs": completion["choices"][0]["logprobs"],
"finish_reason": "tool_calls", "finish_reason": "tool_calls",
} }
], ],
@ -391,7 +411,7 @@ def _convert_completion_to_chat_function(
{ {
"index": 0, "index": 0,
"finish_reason": None, "finish_reason": None,
"logprobs": None, "logprobs": chunk["choices"][0]["logprobs"],
"delta": { "delta": {
"role": None, "role": None,
"content": None, "content": None,
@ -426,7 +446,7 @@ def _convert_completion_to_chat_function(
{ {
"index": 0, "index": 0,
"finish_reason": None, "finish_reason": None,
"logprobs": None, "logprobs": chunk["choices"][0]["logprobs"],
"delta": { "delta": {
"role": None, "role": None,
"content": None, "content": None,
@ -491,7 +511,6 @@ def chat_formatter_to_chat_completion_handler(
temperature: float = 0.2, temperature: float = 0.2,
top_p: float = 0.95, top_p: float = 0.95,
top_k: int = 40, top_k: int = 40,
logprobs: int = 0,
min_p: float = 0.05, min_p: float = 0.05,
typical_p: float = 1.0, typical_p: float = 1.0,
stream: bool = False, stream: bool = False,
@ -512,6 +531,8 @@ def chat_formatter_to_chat_completion_handler(
logits_processor: Optional[llama.LogitsProcessorList] = None, logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None, grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None, logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs, # type: ignore **kwargs, # type: ignore
) -> Union[ ) -> Union[
llama_types.CreateChatCompletionResponse, llama_types.CreateChatCompletionResponse,
@ -530,6 +551,10 @@ def chat_formatter_to_chat_completion_handler(
rstop = result.stop if isinstance(result.stop, list) else [result.stop] rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop stop = stop + rstop
stopping_criteria = None
if result.stopping_criteria is not None:
stopping_criteria = result.stopping_criteria
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose) grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
@ -581,7 +606,7 @@ def chat_formatter_to_chat_completion_handler(
top_k=top_k, top_k=top_k,
min_p=min_p, min_p=min_p,
typical_p=typical_p, typical_p=typical_p,
logprobs=logprobs, logprobs=top_logprobs if logprobs else None,
stream=stream, stream=stream,
stop=stop, stop=stop,
seed=seed, seed=seed,
@ -595,6 +620,7 @@ def chat_formatter_to_chat_completion_handler(
mirostat_eta=mirostat_eta, mirostat_eta=mirostat_eta,
model=model, model=model,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
grammar=grammar, grammar=grammar,
logit_bias=logit_bias, logit_bias=logit_bias,
) )
@ -706,6 +732,9 @@ def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[s
metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE): metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE):
return "mistral-instruct" return "mistral-instruct"
if metadata["tokenizer.chat_template"] == LLAMA3_INSTRUCT_CHAT_TEMPLATE:
return "llama-3"
return None return None
@ -897,6 +926,26 @@ def format_llama2(
return ChatFormatterResponse(prompt=_prompt) return ChatFormatterResponse(prompt=_prompt)
# Chat format for Llama-3 models, see more details at:
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
@register_chat_format("llama-3")
def format_llama3(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
_roles = dict(
system="<|start_header_id|>system<|end_header_id|>\n\n",
user="<|start_header_id|>user<|end_header_id|>\n\n",
assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
)
_begin_token = "<|begin_of_text|>"
_sep = "<|eot_id|>"
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
_prompt = _format_no_colon_single(_begin_token, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
@register_chat_format("alpaca") @register_chat_format("alpaca")
def format_alpaca( def format_alpaca(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
@ -1628,7 +1677,7 @@ def functionary_chat_handler(
} }
], ],
}, },
"logprobs": None, "logprobs": completion["choices"][0]["logprobs"],
"finish_reason": "tool_calls", "finish_reason": "tool_calls",
} }
], ],
@ -2085,7 +2134,7 @@ def functionary_v1_v2_chat_handler(
choices=[ choices=[
{ {
"index": 0, "index": 0,
"logprobs": None, "logprobs": completion["choices"][0]["logprobs"],
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": None if content == "" else content, "content": None if content == "" else content,
@ -2311,11 +2360,14 @@ def chatml_function_calling(
model: Optional[str] = None, model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None, logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None, grammar: Optional[llama.LlamaGrammar] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs, # type: ignore **kwargs, # type: ignore
) -> Union[ ) -> Union[
llama_types.CreateChatCompletionResponse, llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse], Iterator[llama_types.CreateChatCompletionStreamResponse],
]: ]:
print(logprobs)
function_calling_template = ( function_calling_template = (
"{% for message in messages %}" "{% for message in messages %}"
"<|im_start|>{{ message.role }}\n" "<|im_start|>{{ message.role }}\n"
@ -2437,6 +2489,7 @@ def chatml_function_calling(
model=model, model=model,
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
logprobs=top_logprobs if logprobs else None,
), ),
stream=stream, stream=stream,
) )
@ -2549,6 +2602,7 @@ def chatml_function_calling(
typical_p=typical_p, typical_p=typical_p,
stream=stream, stream=stream,
stop=["<|im_end|>"], stop=["<|im_end|>"],
logprobs=top_logprobs if logprobs else None,
max_tokens=None, max_tokens=None,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
@ -2660,7 +2714,7 @@ def chatml_function_calling(
{ {
"finish_reason": "tool_calls", "finish_reason": "tool_calls",
"index": 0, "index": 0,
"logprobs": None, "logprobs": completion["choices"][0]["logprobs"],
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
@ -2701,4 +2755,4 @@ def chatml_function_calling(
}, },
} }
raise ValueError("Automatic streaming tool choice is not supported") raise ValueError("Automatic streaming tool choice is not supported")

View file

@ -237,11 +237,18 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61
# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
LLAMA_FILE_MAGIC_GGSN = 0x6767736E LLAMA_FILE_MAGIC_GGSN = 0x6767736E
# define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
LLAMA_FILE_MAGIC_GGSQ = 0x67677371
# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
# define LLAMA_SESSION_VERSION 5 # define LLAMA_SESSION_VERSION 5
LLAMA_SESSION_VERSION = 5 LLAMA_SESSION_VERSION = 5
# define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ
# define LLAMA_STATE_SEQ_VERSION 1
LLAMA_STATE_SEQ_VERSION = 1
# struct llama_model; # struct llama_model;
llama_model_p = NewType("llama_model_p", int) llama_model_p = NewType("llama_model_p", int)
@ -424,6 +431,11 @@ class llama_token_data(ctypes.Structure):
logit (float): log-odds of the token logit (float): log-odds of the token
p (float): probability of the token""" p (float): probability of the token"""
if TYPE_CHECKING:
id: llama_token
logit: float
p: float
_fields_ = [ _fields_ = [
("id", llama_token), ("id", llama_token),
("logit", ctypes.c_float), ("logit", ctypes.c_float),
@ -447,6 +459,11 @@ class llama_token_data_array(ctypes.Structure):
size (int): size of the array size (int): size of the array
sorted (bool): whether the array is sorted""" sorted (bool): whether the array is sorted"""
if TYPE_CHECKING:
data: CtypesArray[llama_token_data]
size: int
sorted: bool
_fields_ = [ _fields_ = [
("data", llama_token_data_p), ("data", llama_token_data_p),
("size", ctypes.c_size_t), ("size", ctypes.c_size_t),
@ -508,6 +525,15 @@ class llama_batch(ctypes.Structure):
logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
""" """
if TYPE_CHECKING:
n_tokens: int
token: CtypesArray[llama_token]
embd: CtypesArray[ctypes.c_float]
pos: CtypesArray[CtypesArray[llama_pos]]
n_seq_id: CtypesArray[ctypes.c_int]
seq_id: CtypesArray[CtypesArray[llama_seq_id]]
logits: CtypesArray[ctypes.c_int8]
_fields_ = [ _fields_ = [
("n_tokens", ctypes.c_int32), ("n_tokens", ctypes.c_int32),
("token", ctypes.POINTER(llama_token)), ("token", ctypes.POINTER(llama_token)),
@ -602,6 +628,18 @@ class llama_model_params(ctypes.Structure):
use_mmap (bool): use mmap if possible use_mmap (bool): use mmap if possible
use_mlock (bool): force system to keep model in RAM""" use_mlock (bool): force system to keep model in RAM"""
if TYPE_CHECKING:
n_gpu_layers: int
split_mode: int
main_gpu: int
tensor_split: CtypesArray[ctypes.c_float]
progress_callback: Callable[[float, ctypes.c_void_p], bool]
progress_callback_user_data: ctypes.c_void_p
kv_overrides: CtypesArray[llama_model_kv_override]
vocab_only: bool
use_mmap: bool
use_mlock: bool
_fields_ = [ _fields_ = [
("n_gpu_layers", ctypes.c_int32), ("n_gpu_layers", ctypes.c_int32),
("split_mode", ctypes.c_int), ("split_mode", ctypes.c_int),
@ -689,6 +727,34 @@ class llama_context_params(ctypes.Structure):
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
""" """
if TYPE_CHECKING:
seed: int
n_ctx: int
n_batch: int
n_ubatch: int
n_seq_max: int
n_threads: int
n_threads_batch: int
rope_scaling_type: int
pooling_type: int
rope_freq_base: float
rope_freq_scale: float
yarn_ext_factor: float
yarn_attn_factor: float
yarn_beta_fast: float
yarn_beta_slow: float
yarn_orig_ctx: int
defrag_thold: float
cb_eval: Callable[[ctypes.c_void_p, bool], bool]
cb_eval_user_data: ctypes.c_void_p
type_k: int
type_v: int
logits_all: bool
embeddings: bool
offload_kqv: bool
abort_callback: Callable[[ctypes.c_void_p], bool]
abort_callback_data: ctypes.c_void_p
_fields_ = [ _fields_ = [
("seed", ctypes.c_uint32), ("seed", ctypes.c_uint32),
("n_ctx", ctypes.c_uint32), ("n_ctx", ctypes.c_uint32),
@ -764,6 +830,18 @@ class llama_model_quantize_params(ctypes.Structure):
kv_overrides (ctypes.c_void_p): pointer to vector containing overrides kv_overrides (ctypes.c_void_p): pointer to vector containing overrides
""" """
if TYPE_CHECKING:
nthread: int
ftype: int
output_tensor_type: int
token_embedding_type: int
allow_requantize: bool
quantize_output_tensor: bool
only_copy: bool
pure: bool
imatrix: ctypes.c_void_p
kv_overrides: ctypes.c_void_p
_fields_ = [ _fields_ = [
("nthread", ctypes.c_int32), ("nthread", ctypes.c_int32),
("ftype", ctypes.c_int), ("ftype", ctypes.c_int),
@ -821,6 +899,10 @@ LLAMA_GRETYPE_CHAR_ALT = 6
# uint32_t value; // Unicode code point or rule ID # uint32_t value; // Unicode code point or rule ID
# } llama_grammar_element; # } llama_grammar_element;
class llama_grammar_element(ctypes.Structure): class llama_grammar_element(ctypes.Structure):
if TYPE_CHECKING:
type: int
value: int
_fields_ = [ _fields_ = [
("type", ctypes.c_int), ("type", ctypes.c_int),
("value", ctypes.c_uint32), ("value", ctypes.c_uint32),
@ -844,6 +926,17 @@ llama_grammar_element_p = ctypes.POINTER(llama_grammar_element)
# int32_t n_eval; # int32_t n_eval;
# }; # };
class llama_timings(ctypes.Structure): class llama_timings(ctypes.Structure):
if TYPE_CHECKING:
t_start_ms: float
t_end_ms: float
t_load_ms: float
t_sample_ms: float
t_p_eval_ms: float
t_eval_ms: float
n_sample: int
n_p_eval: int
n_eval: int
_fields_ = [ _fields_ = [
("t_start_ms", ctypes.c_double), ("t_start_ms", ctypes.c_double),
("t_end_ms", ctypes.c_double), ("t_end_ms", ctypes.c_double),
@ -944,7 +1037,8 @@ GGML_NUMA_STRATEGY_COUNT = 5
[ctypes.c_int], [ctypes.c_int],
None, None,
) )
def llama_numa_init(numa: int, /): ... def llama_numa_init(numa: int, /):
...
# // Call once at the end of the program - currently only used for MPI # // Call once at the end of the program - currently only used for MPI
@ -969,7 +1063,8 @@ def llama_backend_free():
) )
def llama_load_model_from_file( def llama_load_model_from_file(
path_model: bytes, params: llama_model_params, / path_model: bytes, params: llama_model_params, /
) -> Optional[llama_model_p]: ... ) -> Optional[llama_model_p]:
...
# LLAMA_API void llama_free_model(struct llama_model * model); # LLAMA_API void llama_free_model(struct llama_model * model);
@ -978,7 +1073,8 @@ def llama_load_model_from_file(
[llama_model_p_ctypes], [llama_model_p_ctypes],
None, None,
) )
def llama_free_model(model: llama_model_p, /): ... def llama_free_model(model: llama_model_p, /):
...
# LLAMA_API struct llama_context * llama_new_context_with_model( # LLAMA_API struct llama_context * llama_new_context_with_model(
@ -991,7 +1087,8 @@ def llama_free_model(model: llama_model_p, /): ...
) )
def llama_new_context_with_model( def llama_new_context_with_model(
model: llama_model_p, params: llama_context_params, / model: llama_model_p, params: llama_context_params, /
) -> Optional[llama_context_p]: ... ) -> Optional[llama_context_p]:
...
# // Frees all allocated memory # // Frees all allocated memory
@ -1012,82 +1109,98 @@ def llama_free(ctx: llama_context_p, /):
[], [],
ctypes.c_int64, ctypes.c_int64,
) )
def llama_time_us() -> int: ... def llama_time_us() -> int:
...
# LLAMA_API size_t llama_max_devices(void); # LLAMA_API size_t llama_max_devices(void);
@ctypes_function("llama_max_devices", [], ctypes.c_size_t) @ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int: ... def llama_max_devices() -> int:
...
# LLAMA_API bool llama_supports_mmap (void); # LLAMA_API bool llama_supports_mmap (void);
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool) @ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
def llama_supports_mmap() -> bool: ... def llama_supports_mmap() -> bool:
...
# LLAMA_API bool llama_supports_mlock (void); # LLAMA_API bool llama_supports_mlock (void);
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool) @ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
def llama_supports_mlock() -> bool: ... def llama_supports_mlock() -> bool:
...
# LLAMA_API bool llama_supports_gpu_offload(void); # LLAMA_API bool llama_supports_gpu_offload(void);
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool) @ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
def llama_supports_gpu_offload() -> bool: ... def llama_supports_gpu_offload() -> bool:
...
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); # LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) @ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ... def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
...
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32) @ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ctx(ctx: llama_context_p, /) -> int: ... def llama_n_ctx(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32) @ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_batch(ctx: llama_context_p, /) -> int: ... def llama_n_batch(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32) @ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ubatch(ctx: llama_context_p, /) -> int: ... def llama_n_ubatch(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32) @ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_seq_max(ctx: llama_context_p, /) -> int: ... def llama_n_seq_max(ctx: llama_context_p, /) -> int:
...
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); # LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) @ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_vocab_type(model: llama_model_p, /) -> int: ... def llama_vocab_type(model: llama_model_p, /) -> int:
...
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); # LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int) @ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_rope_type(model: llama_model_p, /) -> int: ... def llama_rope_type(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_vocab(model: llama_model_p, /) -> int: ... def llama_n_vocab(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); # LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_ctx_train(model: llama_model_p, /) -> int: ... def llama_n_ctx_train(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model); # LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_embd(model: llama_model_p, /) -> int: ... def llama_n_embd(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_layer (const struct llama_model * model); # LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
@ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_layer(model: llama_model_p, /) -> int: ... def llama_n_layer(model: llama_model_p, /) -> int:
...
# // Get the model's RoPE frequency scaling factor # // Get the model's RoPE frequency scaling factor
@ -1351,6 +1464,9 @@ class llama_kv_cache_view_cell(ctypes.Structure):
pos (llama_pos): The position for this cell. Takes KV cache shifts into account. pos (llama_pos): The position for this cell. Takes KV cache shifts into account.
May be negative if the cell is not populated.""" May be negative if the cell is not populated."""
if TYPE_CHECKING:
pos: llama_pos
_fields_ = [("pos", llama_pos)] _fields_ = [("pos", llama_pos)]
@ -1387,6 +1503,16 @@ class llama_kv_cache_view_cell(ctypes.Structure):
# llama_seq_id * cells_sequences; # llama_seq_id * cells_sequences;
# }; # };
class llama_kv_cache_view(ctypes.Structure): class llama_kv_cache_view(ctypes.Structure):
if TYPE_CHECKING:
n_cells: int
n_max_seq: int
token_count: int
used_cells: int
max_contiguous: int
max_contiguous_idx: int
cells: CtypesArray[llama_kv_cache_view_cell]
cells_sequences: CtypesArray[llama_seq_id]
_fields_ = [ _fields_ = [
("n_cells", ctypes.c_int32), ("n_cells", ctypes.c_int32),
("n_max_seq", ctypes.c_int32), ("n_max_seq", ctypes.c_int32),
@ -1467,6 +1593,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
# // seq_id < 0 : match any sequence # // seq_id < 0 : match any sequence
# // p0 < 0 : [0, p1] # // p0 < 0 : [0, p1]
# // p1 < 0 : [p0, inf) # // p1 < 0 : [p0, inf)
@ -1493,6 +1620,9 @@ def llama_kv_cache_seq_rm(
/, /,
) -> bool: ) -> bool:
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1) """Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
seq_id < 0 : match any sequence seq_id < 0 : match any sequence
p0 < 0 : [0, p1] p0 < 0 : [0, p1]
p1 < 0 : [p0, inf)""" p1 < 0 : [p0, inf)"""
@ -1652,7 +1782,16 @@ def llama_kv_cache_update(ctx: llama_context_p, /):
# Returns the maximum size in bytes of the state (rng, logits, embedding # Returns the maximum size in bytes of the state (rng, logits, embedding
# and kv_cache) - will often be smaller after compacting tokens # and kv_cache) - will often be smaller after compacting tokens
# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); # LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
@ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t)
def llama_state_get_size(ctx: llama_context_p, /) -> int:
"""Returns the maximum size in bytes of the state (rng, logits, embedding
and kv_cache) - will often be smaller after compacting tokens"""
...
# LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
# "use llama_state_get_size instead");
@ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t) @ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
def llama_get_state_size(ctx: llama_context_p, /) -> int: def llama_get_state_size(ctx: llama_context_p, /) -> int:
"""Returns the maximum size in bytes of the state (rng, logits, embedding """Returns the maximum size in bytes of the state (rng, logits, embedding
@ -1663,9 +1802,30 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int:
# Copies the state to the specified destination address. # Copies the state to the specified destination address.
# Destination needs to have allocated enough memory. # Destination needs to have allocated enough memory.
# Returns the number of bytes copied # Returns the number of bytes copied
# LLAMA_API size_t llama_copy_state_data( # LLAMA_API size_t llama_state_get_data(
# struct llama_context * ctx, # struct llama_context * ctx,
# uint8_t * dst); # uint8_t * dst);
@ctypes_function(
"llama_state_get_data",
[
llama_context_p_ctypes,
ctypes.POINTER(ctypes.c_uint8),
],
ctypes.c_size_t,
)
def llama_state_get_data(
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
) -> int:
"""Copies the state to the specified destination address.
Destination needs to have allocated enough memory.
Returns the number of bytes copied"""
...
# LLAMA_API DEPRECATED(size_t llama_copy_state_data(
# struct llama_context * ctx,
# uint8_t * dst),
# "use llama_state_get_data instead");
@ctypes_function( @ctypes_function(
"llama_copy_state_data", "llama_copy_state_data",
[ [
@ -1685,9 +1845,26 @@ def llama_copy_state_data(
# // Set the state reading from the specified address # // Set the state reading from the specified address
# // Returns the number of bytes read # // Returns the number of bytes read
# LLAMA_API size_t llama_set_state_data( # LLAMA_API size_t llama_state_set_data(
# struct llama_context * ctx, # struct llama_context * ctx,
# const uint8_t * src); # const uint8_t * src);
@ctypes_function(
"llama_state_set_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
ctypes.c_size_t,
)
def llama_state_set_data(
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
) -> int:
"""Set the state reading from the specified address
Returns the number of bytes read"""
...
# LLAMA_API DEPRECATED(size_t llama_set_state_data(
# struct llama_context * ctx,
# const uint8_t * src),
# "use llama_state_set_data instead");
@ctypes_function( @ctypes_function(
"llama_set_state_data", "llama_set_state_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)], [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
@ -1701,12 +1878,41 @@ def llama_set_state_data(
# Save/load session file # Save/load session file
# LLAMA_API bool llama_load_session_file( # LLAMA_API bool llama_state_load_file(
# struct llama_context * ctx, # struct llama_context * ctx,
# const char * path_session, # const char * path_session,
# llama_token * tokens_out, # llama_token * tokens_out,
# size_t n_token_capacity, # size_t n_token_capacity,
# size_t * n_token_count_out); # size_t * n_token_count_out);
@ctypes_function(
"llama_state_load_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
llama_token_p,
ctypes.c_size_t,
ctypes.POINTER(ctypes.c_size_t),
],
ctypes.c_bool,
)
def llama_state_load_file(
ctx: llama_context_p,
path_session: bytes,
tokens_out: CtypesArray[llama_token],
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> bool:
...
# LLAMA_API DEPRECATED(bool llama_load_session_file(
# struct llama_context * ctx,
# const char * path_session,
# llama_token * tokens_out,
# size_t n_token_capacity,
# size_t * n_token_count_out),
# "use llama_state_load_file instead");
@ctypes_function( @ctypes_function(
"llama_load_session_file", "llama_load_session_file",
[ [
@ -1725,14 +1931,41 @@ def llama_load_session_file(
n_token_capacity: Union[ctypes.c_size_t, int], n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/, /,
) -> int: ... ) -> int:
...
# LLAMA_API bool llama_save_session_file( # LLAMA_API bool llama_state_save_file(
# struct llama_context * ctx, # struct llama_context * ctx,
# const char * path_session, # const char * path_session,
# const llama_token * tokens, # const llama_token * tokens,
# size_t n_token_count); # size_t n_token_count);
@ctypes_function(
"llama_state_save_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
llama_token_p,
ctypes.c_size_t,
],
ctypes.c_bool,
)
def llama_state_save_file(
ctx: llama_context_p,
path_session: bytes,
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/,
) -> bool:
...
# LLAMA_API DEPRECATED(bool llama_save_session_file(
# struct llama_context * ctx,
# const char * path_session,
# const llama_token * tokens,
# size_t n_token_count),
# "use llama_state_save_file instead");
@ctypes_function( @ctypes_function(
"llama_save_session_file", "llama_save_session_file",
[ [
@ -1749,7 +1982,118 @@ def llama_save_session_file(
tokens: CtypesArray[llama_token], tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int], n_token_count: Union[ctypes.c_size_t, int],
/, /,
) -> int: ... ) -> int:
...
# // Get the exact size needed to copy the KV cache of a single sequence
# LLAMA_API size_t llama_state_seq_get_size(
# struct llama_context * ctx,
# llama_seq_id seq_id);
@ctypes_function(
"llama_state_seq_get_size",
[llama_context_p_ctypes, llama_seq_id],
ctypes.c_size_t,
)
def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int:
"""Get the exact size needed to copy the KV cache of a single sequence"""
...
# // Copy the KV cache of a single sequence into the specified buffer
# LLAMA_API size_t llama_state_seq_get_data(
# struct llama_context * ctx,
# uint8_t * dst,
# llama_seq_id seq_id);
@ctypes_function(
"llama_state_seq_get_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
ctypes.c_size_t,
)
def llama_state_seq_get_data(
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], seq_id: llama_seq_id, /
) -> int:
"""Copy the KV cache of a single sequence into the specified buffer"""
...
# // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
# // Returns:
# // - Positive: Ok
# // - Zero: Failed to load
# LLAMA_API size_t llama_state_seq_set_data(
# struct llama_context * ctx,
# const uint8_t * src,
# llama_seq_id dest_seq_id);
@ctypes_function(
"llama_state_seq_set_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
ctypes.c_size_t,
)
def llama_state_seq_set_data(
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], dest_seq_id: llama_seq_id, /
) -> int:
"""Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence"""
...
# LLAMA_API size_t llama_state_seq_save_file(
# struct llama_context * ctx,
# const char * filepath,
# llama_seq_id seq_id,
# const llama_token * tokens,
# size_t n_token_count);
@ctypes_function(
"llama_state_seq_save_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
llama_seq_id,
llama_token_p,
ctypes.c_size_t,
],
ctypes.c_size_t,
)
def llama_state_seq_save_file(
ctx: llama_context_p,
filepath: bytes,
seq_id: llama_seq_id,
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/,
) -> int:
...
# LLAMA_API size_t llama_state_seq_load_file(
# struct llama_context * ctx,
# const char * filepath,
# llama_seq_id dest_seq_id,
# llama_token * tokens_out,
# size_t n_token_capacity,
# size_t * n_token_count_out);
@ctypes_function(
"llama_state_seq_load_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
llama_seq_id,
llama_token_p,
ctypes.c_size_t,
ctypes.POINTER(ctypes.c_size_t),
],
ctypes.c_size_t,
)
def llama_state_seq_load_file(
ctx: llama_context_p,
filepath: bytes,
dest_seq_id: llama_seq_id,
tokens_out: CtypesArray[llama_token],
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> int:
...
# // # //
@ -1930,8 +2274,9 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
... ...
# // Logits for the ith token. Equivalent to: # // Logits for the ith token. For positive indices, Equivalent to:
# // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab # // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
# // Negative indicies can be used to access logits in reverse order, -1 is the last logit.
# // returns NULL for invalid ids. # // returns NULL for invalid ids.
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); # LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
@ctypes_function( @ctypes_function(
@ -1963,8 +2308,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
... ...
# // Get the embeddings for the ith token. Equivalent to: # // Get the embeddings for the ith token. For positive indices, Equivalent to:
# // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd # // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
# // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding.
# // shape: [n_embd] (1-dimensional) # // shape: [n_embd] (1-dimensional)
# // returns NULL for invalid ids. # // returns NULL for invalid ids.
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); # LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
@ -2010,7 +2356,8 @@ def llama_get_embeddings_seq(
) )
def llama_token_get_text( def llama_token_get_text(
model: llama_model_p, token: Union[llama_token, int], / model: llama_model_p, token: Union[llama_token, int], /
) -> bytes: ... ) -> bytes:
...
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); # LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
@ -2019,7 +2366,8 @@ def llama_token_get_text(
) )
def llama_token_get_score( def llama_token_get_score(
model: llama_model_p, token: Union[llama_token, int], / model: llama_model_p, token: Union[llama_token, int], /
) -> float: ... ) -> float:
...
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
@ -2028,7 +2376,20 @@ def llama_token_get_score(
) )
def llama_token_get_type( def llama_token_get_type(
model: llama_model_p, token: Union[llama_token, int], / model: llama_model_p, token: Union[llama_token, int], /
) -> int: ... ) -> int:
...
# // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
# LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
@ctypes_function(
"llama_token_is_eog", [llama_model_p_ctypes, llama_token], ctypes.c_bool
)
def llama_token_is_eog(
model: llama_model_p, token: Union[llama_token, int], /
) -> bool:
"""Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)"""
...
# // Special tokens # // Special tokens
@ -2048,6 +2409,20 @@ def llama_token_eos(model: llama_model_p, /) -> int:
... ...
# LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
@ctypes_function("llama_token_cls", [llama_model_p_ctypes], llama_token)
def llama_token_cls(model: llama_model_p, /) -> int:
"""classification"""
...
# LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
@ctypes_function("llama_token_sep", [llama_model_p_ctypes], llama_token)
def llama_token_sep(model: llama_model_p, /) -> int:
"""sentence separator"""
...
# LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line # LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token) @ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token)
def llama_token_nl(model: llama_model_p, /) -> int: def llama_token_nl(model: llama_model_p, /) -> int:
@ -2071,7 +2446,7 @@ def llama_add_eos_token(model: llama_model_p, /) -> int:
... ...
# // codellama infill tokens # // Codellama infill tokens
# LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix # LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token) @ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token)
def llama_token_prefix(model: llama_model_p) -> int: def llama_token_prefix(model: llama_model_p) -> int:
@ -2081,17 +2456,20 @@ def llama_token_prefix(model: llama_model_p) -> int:
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle # LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token) @ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
def llama_token_middle(model: llama_model_p, /) -> int: ... def llama_token_middle(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix # LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token) @ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
def llama_token_suffix(model: llama_model_p, /) -> int: ... def llama_token_suffix(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle # LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token) @ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
def llama_token_eot(model: llama_model_p, /) -> int: ... def llama_token_eot(model: llama_model_p, /) -> int:
...
# // # //
@ -2103,16 +2481,16 @@ def llama_token_eot(model: llama_model_p, /) -> int: ...
# /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. # /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
# /// @return Returns the number of tokens on success, no more than n_tokens_max # /// @return Returns the number of tokens on success, no more than n_tokens_max
# /// @return Returns a negative number on failure - the number of tokens that would have been returned # /// @return Returns a negative number on failure - the number of tokens that would have been returned
# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. # /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
# /// Does not insert a leading space. # /// as plaintext. Does not insert a leading space.
# LLAMA_API int32_t llama_tokenize( # LLAMA_API int32_t llama_tokenize(
# const struct llama_model * model, # const struct llama_model * model,
# const char * text, # const char * text,
# int32_t text_len, # int32_t text_len,
# llama_token * tokens, # llama_token * tokens,
# int32_t n_tokens_max, # int32_t n_tokens_max,
# bool add_bos, # bool add_special,
# bool special); # bool parse_special);
@ctypes_function( @ctypes_function(
"llama_tokenize", "llama_tokenize",
[ [
@ -2132,8 +2510,8 @@ def llama_tokenize(
text_len: Union[ctypes.c_int, int], text_len: Union[ctypes.c_int, int],
tokens: CtypesArray[llama_token], tokens: CtypesArray[llama_token],
n_tokens_max: Union[ctypes.c_int, int], n_tokens_max: Union[ctypes.c_int, int],
add_bos: Union[ctypes.c_bool, bool], add_special: Union[ctypes.c_bool, bool],
special: Union[ctypes.c_bool, bool], parse_special: Union[ctypes.c_bool, bool],
/, /,
) -> int: ) -> int:
"""Convert the provided text into tokens. """Convert the provided text into tokens.
@ -2144,9 +2522,8 @@ def llama_tokenize(
text_len: The length of the text. text_len: The length of the text.
tokens: The tokens pointer must be large enough to hold the resulting tokens. tokens: The tokens pointer must be large enough to hold the resulting tokens.
n_max_tokens: The maximum number of tokens to return. n_max_tokens: The maximum number of tokens to return.
add_bos: Whether to add a beginning-of-sentence token. add_special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. parse_special: Allow parsing special tokens.
Does not insert a leading space.
Returns: Returns:
Returns the number of tokens on success, no more than n_tokens_max Returns the number of tokens on success, no more than n_tokens_max
@ -2159,11 +2536,13 @@ def llama_tokenize(
# // Uses the vocabulary in the provided context. # // Uses the vocabulary in the provided context.
# // Does not write null terminator to the buffer. # // Does not write null terminator to the buffer.
# // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. # // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
# // @param special If true, special tokens are rendered in the output.
# LLAMA_API int32_t llama_token_to_piece( # LLAMA_API int32_t llama_token_to_piece(
# const struct llama_model * model, # const struct llama_model * model,
# llama_token token, # llama_token token,
# char * buf, # char * buf,
# int32_t length); # int32_t length,
# bool special);
@ctypes_function( @ctypes_function(
"llama_token_to_piece", "llama_token_to_piece",
[ [
@ -2171,6 +2550,7 @@ def llama_tokenize(
llama_token, llama_token,
ctypes.c_char_p, ctypes.c_char_p,
ctypes.c_int32, ctypes.c_int32,
ctypes.c_bool,
], ],
ctypes.c_int32, ctypes.c_int32,
) )
@ -2179,13 +2559,20 @@ def llama_token_to_piece(
token: Union[llama_token, int], token: Union[llama_token, int],
buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]], buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]],
length: Union[ctypes.c_int, int], length: Union[ctypes.c_int, int],
special: Union[ctypes.c_bool, bool],
/, /,
) -> int: ) -> int:
"""Token Id -> Piece. """Token Id -> Piece.
Uses the vocabulary in the provided context. Uses the vocabulary in the provided context.
Does not write null terminator to the buffer. Does not write null terminator to the buffer.
User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
"""
Args:
model: The model to use for tokenization.
token: The token to convert.
buf: The buffer to write the token to.
length: The length of the buffer.
special: If true, special tokens are rendered in the output."""
... ...
@ -2223,7 +2610,8 @@ def llama_chat_apply_template(
chat: CtypesArray[llama_chat_message], chat: CtypesArray[llama_chat_message],
n_msg: int, n_msg: int,
/, /,
) -> int: ... ) -> int:
...
# // # //
@ -2753,6 +3141,12 @@ def llama_grammar_accept_token(
# bool eob; // Callback should set this to true when a beam is at end-of-beam. # bool eob; // Callback should set this to true when a beam is at end-of-beam.
# }; # };
class llama_beam_view(ctypes.Structure): class llama_beam_view(ctypes.Structure):
if TYPE_CHECKING:
tokens: CtypesArray[llama_token]
n_tokens: int
p: float
eob: bool
_fields_ = [ _fields_ = [
("tokens", llama_token_p), ("tokens", llama_token_p),
("n_tokens", ctypes.c_size_t), ("n_tokens", ctypes.c_size_t),
@ -2772,6 +3166,12 @@ class llama_beam_view(ctypes.Structure):
# bool last_call; // True iff this is the last callback invocation. # bool last_call; // True iff this is the last callback invocation.
# }; # };
class llama_beams_state(ctypes.Structure): class llama_beams_state(ctypes.Structure):
if TYPE_CHECKING:
beam_views: CtypesArray[llama_beam_view]
n_beams: int
common_prefix_length: int
last_call: bool
_fields_ = [ _fields_ = [
("beam_views", ctypes.POINTER(llama_beam_view)), ("beam_views", ctypes.POINTER(llama_beam_view)),
("n_beams", ctypes.c_size_t), ("n_beams", ctypes.c_size_t),
@ -2824,7 +3224,8 @@ def llama_beam_search(
n_past: Union[ctypes.c_int, int], n_past: Union[ctypes.c_int, int],
n_predict: Union[ctypes.c_int, int], n_predict: Union[ctypes.c_int, int],
/, /,
): ... ):
...
# /// @details Build a split GGUF final path for this chunk. # /// @details Build a split GGUF final path for this chunk.
@ -2943,4 +3344,5 @@ def llama_log_set(
[ctypes.c_void_p, llama_context_p_ctypes], [ctypes.c_void_p, llama_context_p_ctypes],
None, None,
) )
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ... def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
...

View file

@ -5,11 +5,12 @@ from pathlib import Path
import sys import sys
from ctypes import * # type: ignore from ctypes import * # type: ignore
from enum import Enum from enum import Enum
from itertools import islice from itertools import islice, groupby
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict, Dict,
Set,
Generic, Generic,
List, List,
Optional, Optional,
@ -1391,145 +1392,561 @@ from typing import List, Optional
# whitespace. Also maybe improves generation quality? # whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?' SPACE_RULE = '" "?'
PRIMITIVE_RULES = {
"boolean": '("true" | "false") space',
"number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
"integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
"string": r""" "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space """,
"null": '"null" space',
}
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'
def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
if not separator_rule:
if min_items == 0 and max_items == 1:
return f'{item_rule}?'
elif min_items == 1 and max_items is None:
return f'{item_rule}+'
result = ''
if min_items > 0:
if item_rule_is_literal and separator_rule is None:
result = '"' + (item_rule[1:-1] * min_items) + '"'
else:
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)
def opt_repetitions(up_to_n, prefix_with_sep=False):
'''
- n=4, no sep: '(a (a (a (a)?)?)?)?'
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
'''
content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
if up_to_n == 0:
return ''
elif up_to_n == 1:
return f'({content})?'
elif separator_rule and not prefix_with_sep:
return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
else:
return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)
if min_items > 0 and max_items != min_items:
result += ' '
if max_items is not None:
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
else:
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
if min_items == 0 and separator_rule:
result = f'({item_rule} {item_operator}*)?'
else:
result += f'{item_operator}*'
return result
class BuiltinRule:
def __init__(self, content: str, deps: list = None):
self.content = content
self.deps = deps or []
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)
PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
'null' : BuiltinRule('"null" space', []),
}
# TODO: support "uri", "email" string formats
STRING_FORMAT_RULES = {
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}
DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
class SchemaConverter: class SchemaConverter:
def __init__(self, prop_order): def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._prop_order = prop_order self._prop_order = prop_order
self._rules = {"space": SPACE_RULE} self._allow_fetch = allow_fetch
self._defs: Dict[str, Any] = {} self._dotall = dotall
self._raw_pattern = raw_pattern
self._rules = {
'space': SPACE_RULE,
}
self._refs = {}
self._refs_being_resolved = set()
def _format_literal(self, literal: str): def _format_literal(self, literal):
escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub( escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
) )
return f'"{escaped}"' return f'"{escaped}"'
def _add_rule(self, name: str, rule: str): def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
esc_name = INVALID_RULE_CHARS_RE.sub("-", name) '''
not_literal('a') -> '[^a]'
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
'''
assert len(literal) > 0, 'Empty literal not supported'
def recurse(i: int):
c = literal[i]
if maybe_escaped_underscores and c == '_':
yield f'[^{c}\\\\]'
yield ' | '
yield f'"\\\\"? "{c}"'
else:
yield f'[^{c}]'
if i < len(literal) - 1:
yield ' | '
yield self._format_literal(c)
yield ' ('
yield from recurse(i + 1)
yield ')?'
return ''.join(('(', *recurse(0), ')'))
def _add_rule(self, name, rule):
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
if esc_name not in self._rules or self._rules[esc_name] == rule: if esc_name not in self._rules or self._rules[esc_name] == rule:
key = esc_name key = esc_name
else: else:
i = 0 i = 0
while f"{esc_name}{i}" in self._rules: while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
i += 1 i += 1
key = f"{esc_name}{i}" key = f'{esc_name}{i}'
self._rules[key] = rule self._rules[key] = rule
return key return key
def visit(self, schema: Dict[str, Any], name: str) -> str: def resolve_refs(self, schema: dict, url: str):
rule_name = name or "root" '''
Resolves all $ref fields in the given schema, fetching any remote schemas,
replacing $ref with absolute reference URL and populating self._refs with the
respective referenced (sub)schema dictionaries.
'''
def visit(n: dict):
if isinstance(n, list):
return [visit(x) for x in n]
elif isinstance(n, dict):
ref = n.get('$ref')
if ref is not None and ref not in self._refs:
if ref.startswith('https://'):
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
import requests
if "$defs" in schema: frag_split = ref.split('#')
# add defs to self._defs for later inlining base_url = frag_split[0]
for def_name, def_schema in schema["$defs"].items():
self._defs[def_name] = def_schema
if "oneOf" in schema or "anyOf" in schema: target = self._refs.get(base_url)
rule = " | ".join( if target is None:
( target = self.resolve_refs(requests.get(ref).json(), base_url)
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') self._refs[base_url] = target
for i, alt_schema in enumerate(
schema.get("oneOf") or schema["anyOf"] if len(frag_split) == 1 or frag_split[-1] == '':
) return target
) elif ref.startswith('#/'):
target = schema
ref = f'{url}{ref}'
n['$ref'] = ref
else:
raise ValueError(f'Unsupported ref {ref}')
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
self._refs[ref] = target
else:
for v in n.values():
visit(v)
return n
return visit(schema)
def _generate_union_rule(self, name, alt_schemas):
return ' | '.join((
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
for i, alt_schema in enumerate(alt_schemas)
))
def _visit_pattern(self, pattern, name):
'''
Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
we define sub-rules to keep the output lean.
'''
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
sub_rule_ids = {}
i = 0
length = len(pattern)
def to_rule(s: Tuple[str, bool]) -> str:
(txt, is_literal) = s
return "\"" + txt + "\"" if is_literal else txt
def transform() -> Tuple[str, bool]:
'''
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
'''
nonlocal i
nonlocal pattern
nonlocal sub_rule_ids
start = i
# For each component of this sequence, store its string representation and whether it's a literal.
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
# (GBNF's syntax is luckily very close to regular expressions!)
seq: list[Tuple[str, bool]] = []
def get_dot():
if self._dotall:
rule = DOTALL
else:
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = DOT
return self._add_rule(f'dot', rule)
def join_seq():
nonlocal seq
ret = []
for is_literal, g in groupby(seq, lambda x: x[1]):
if is_literal:
ret.append((''.join(x[0] for x in g), True))
else:
ret.extend(g)
if len(ret) == 1:
return ret[0]
return (' '.join(to_rule(x) for x in seq), False)
while i < length:
c = pattern[i]
if c == '.':
seq.append((get_dot(), False))
i += 1
elif c == '(':
i += 1
if i < length:
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
seq.append((f'({to_rule(transform())})', False))
elif c == ')':
i += 1
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
return join_seq()
elif c == '[':
square_brackets = c
i += 1
while i < length and pattern[i] != ']':
if pattern[i] == '\\':
square_brackets += pattern[i:i+2]
i += 2
else:
square_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
square_brackets += ']'
i += 1
seq.append((square_brackets, False))
elif c == '|':
seq.append(('|', False))
i += 1
elif c in ('*', '+', '?'):
seq[-1] = (to_rule(seq[-1]) + c, False)
i += 1
elif c == '{':
curly_brackets = c
i += 1
while i < length and pattern[i] != '}':
curly_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
curly_brackets += '}'
i += 1
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
min_times = 0
max_times = None
try:
if len(nums) == 1:
min_times = int(nums[0])
max_times = min_times
else:
assert len(nums) == 2
min_times = int(nums[0]) if nums[0] else 0
max_times = int(nums[1]) if nums[1] else None
except ValueError:
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
(sub, sub_is_literal) = seq[-1]
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
else:
literal = ''
while i < length:
if pattern[i] == '\\' and i < length - 1:
next = pattern[i + 1]
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
i += 1
literal += pattern[i]
i += 1
else:
literal += pattern[i:i+2]
i += 2
elif pattern[i] == '"' and not self._raw_pattern:
literal += '\\"'
i += 1
elif pattern[i] not in NON_LITERAL_SET and \
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
literal += pattern[i]
i += 1
else:
break
if literal:
seq.append((literal, True))
return join_seq()
return self._add_rule(
name,
to_rule(transform()) if self._raw_pattern \
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]
ref_name = self.visit(resolved, ref_name)
self._refs_being_resolved.remove(ref)
return ref_name
def _generate_constant_rule(self, value):
return self._format_literal(json.dumps(value))
def visit(self, schema, name):
schema_type = schema.get('type')
schema_format = schema.get('format')
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
if (ref := schema.get('$ref')) is not None:
return self._add_rule(rule_name, self._resolve_ref(ref))
elif 'oneOf' in schema or 'anyOf' in schema:
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
elif isinstance(schema_type, list):
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
elif 'const' in schema:
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
elif 'enum' in schema:
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'object') and \
('properties' in schema or \
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
required = set(schema.get('required', []))
properties = list(schema.get('properties', {}).items())
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
elif schema_type in (None, 'object') and 'allOf' in schema:
required = set()
properties = []
hybrid_name = name
def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None:
comp_schema = self._refs[ref]
if 'properties' in comp_schema:
for prop_name, prop_schema in comp_schema['properties'].items():
properties.append((prop_name, prop_schema))
if is_required:
required.add(prop_name)
for t in schema['allOf']:
if 'anyOf' in t:
for tt in t['anyOf']:
add_component(tt, is_required=False)
else:
add_component(t, is_required=True)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
if isinstance(items, list):
return self._add_rule(
rule_name,
'"[" space ' +
' "," space '.join(
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
for i, item in enumerate(items)) +
' "]" space')
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_primitive(
'root' if rule_name == 'root' else schema_format,
PRIMITIVE_RULES['uuid']
) )
return self._add_rule(rule_name, rule)
elif "const" in schema: elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
return self._add_rule(rule_name, self._format_literal(schema["const"])) prim_name = f'{schema_format}-string'
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
elif "enum" in schema: elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
rule = " | ".join((self._format_literal(v) for v in schema["enum"])) char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
return self._add_rule(rule_name, rule) min_len = schema.get('minLength', 0)
max_len = schema.get('maxLength')
elif "$ref" in schema: return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
ref = schema["$ref"]
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
# inline $defs
def_name = ref[len("#/$defs/") :]
def_schema = self._defs[def_name]
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
elif (schema_type == 'object') or (len(schema) == 0):
schema_type: Optional[str] = schema.get("type") # type: ignore return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
if schema_type == "object" and "properties" in schema:
# TODO: `required` keyword
if self._prop_order:
prop_order = self._prop_order
prop_pairs = sorted(
schema["properties"].items(),
# sort by position in prop_order (if specified) then by key
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
)
else:
prop_pairs = schema["properties"].items()
rule = '"{" space'
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
prop_rule_name = self.visit(
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
)
if i > 0:
rule += ' "," space'
rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
rule += ' "}" space'
return self._add_rule(rule_name, rule)
elif schema_type == "array" and "items" in schema:
# TODO `prefixItems` keyword
item_rule_name = self.visit(
schema["items"], f'{name}{"-" if name else ""}item'
)
list_item_operator = f'("," space {item_rule_name})'
successive_items = ""
min_items = schema.get("minItems", 0)
if min_items > 0:
first_item = f"({item_rule_name})"
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
else:
first_item = f"({item_rule_name})?"
max_items = schema.get("maxItems")
if max_items is not None and max_items > min_items:
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
rule = f'"[" space {first_item} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
else: else:
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
return self._add_rule( # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
"root" if rule_name == "root" else schema_type, return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
PRIMITIVE_RULES[schema_type],
def _add_primitive(self, name: str, rule: BuiltinRule):
n = self._add_rule(name, rule.content)
for dep in rule.deps:
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
assert dep_rule, f'Rule {dep} not known'
if dep not in self._rules:
self._add_primitive(dep, dep_rule)
return n
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
prop_order = self._prop_order
# sort by position in prop_order (if specified) then by original order
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
prop_kv_rule_names = {}
for prop_name, prop_schema in properties:
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
prop_kv_rule_names[prop_name] = self._add_rule(
f'{name}{"-" if name else ""}{prop_name}-kv',
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
) )
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]
if additional_properties == True or isinstance(additional_properties, dict):
sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
)
optional_props.append("*")
rule = '"{" space '
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
if optional_props:
rule += ' ('
if required_props:
rule += ' "," space ( '
def get_recursive_refs(ks, first_is_optional):
[k, *rest] = ks
kv_rule_name = prop_kv_rule_names[k]
if k == '*':
res = self._add_rule(
f'{name}{"-" if name else ""}additional-kvs',
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
)
elif first_is_optional:
res = f'( "," space {kv_rule_name} )?'
else:
res = kv_rule_name
if len(rest) > 0:
res += ' ' + self._add_rule(
f'{name}{"-" if name else ""}{k}-rest',
get_recursive_refs(rest, first_is_optional=True)
)
return res
rule += ' | '.join(
get_recursive_refs(optional_props[i:], first_is_optional=False)
for i in range(len(optional_props))
)
if required_props:
rule += ' )'
rule += ' )?'
rule += ' "}" space'
return rule
def format_grammar(self): def format_grammar(self):
return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) return '\n'.join(
f'{name} ::= {rule}'
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
)
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
prop_order = prop_order or [] prop_order = prop_order or []
schema = json.loads(schema) schema = json.loads(schema)
prop_order = {name: idx for idx, name in enumerate(prop_order)} prop_order = {name: idx for idx, name in enumerate(prop_order)}
converter = SchemaConverter(prop_order) converter = SchemaConverter(prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False)
schema = converter.resolve_refs(schema, "stdin")
converter.visit(schema, "") converter.visit(schema, "")
return converter.format_grammar() return converter.format_grammar()

View file

@ -59,7 +59,16 @@ def main():
if not os.path.exists(config_file): if not os.path.exists(config_file):
raise ValueError(f"Config file {config_file} not found!") raise ValueError(f"Config file {config_file} not found!")
with open(config_file, "rb") as f: with open(config_file, "rb") as f:
config_file_settings = ConfigFileSettings.model_validate_json(f.read()) # Check if yaml file
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
import yaml
import json
config_file_settings = ConfigFileSettings.model_validate_json(
json.dumps(yaml.safe_load(f))
)
else:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
server_settings = ServerSettings.model_validate(config_file_settings) server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models model_settings = config_file_settings.models
else: else:

View file

@ -87,6 +87,13 @@ def get_llama_proxy():
llama_outer_lock.release() llama_outer_lock.release()
_ping_message_factory = None
def set_ping_message_factory(factory):
global _ping_message_factory
_ping_message_factory = factory
def create_app( def create_app(
settings: Settings | None = None, settings: Settings | None = None,
server_settings: ServerSettings | None = None, server_settings: ServerSettings | None = None,
@ -97,7 +104,15 @@ def create_app(
if not os.path.exists(config_file): if not os.path.exists(config_file):
raise ValueError(f"Config file {config_file} not found!") raise ValueError(f"Config file {config_file} not found!")
with open(config_file, "rb") as f: with open(config_file, "rb") as f:
config_file_settings = ConfigFileSettings.model_validate_json(f.read()) # Check if yaml file
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
import yaml
config_file_settings = ConfigFileSettings.model_validate_json(
json.dumps(yaml.safe_load(f))
)
else:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
server_settings = ServerSettings.model_validate(config_file_settings) server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models model_settings = config_file_settings.models
@ -130,6 +145,9 @@ def create_app(
assert model_settings is not None assert model_settings is not None
set_llama_proxy(model_settings=model_settings) set_llama_proxy(model_settings=model_settings)
if server_settings.disable_ping_events:
set_ping_message_factory(lambda: bytes())
return app return app
@ -294,6 +312,7 @@ async def create_completion(
iterator=iterator(), iterator=iterator(),
), ),
sep="\n", sep="\n",
ping_message_factory=_ping_message_factory,
) )
else: else:
return iterator_or_completion return iterator_or_completion
@ -462,6 +481,7 @@ async def create_chat_completion(
iterator=iterator(), iterator=iterator(),
), ),
sep="\n", sep="\n",
ping_message_factory=_ping_message_factory,
) )
else: else:
return iterator_or_completion return iterator_or_completion

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import multiprocessing import multiprocessing
from typing import Optional, List, Literal, Union from typing import Optional, List, Literal, Union
from pydantic import Field from pydantic import Field, root_validator
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
import llama_cpp import llama_cpp
@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
n_threads: int = Field( n_threads: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1), default=max(multiprocessing.cpu_count() // 2, 1),
ge=1, ge=1,
description="The number of threads to use.", description="The number of threads to use. Use -1 for max cpu threads",
) )
n_threads_batch: int = Field( n_threads_batch: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1), default=max(multiprocessing.cpu_count(), 1),
ge=0, ge=0,
description="The number of threads to use when batch processing.", description="The number of threads to use when batch processing. Use -1 for max cpu threads",
) )
rope_scaling_type: int = Field( rope_scaling_type: int = Field(
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
default=True, description="Whether to print debug information." default=True, description="Whether to print debug information."
) )
@root_validator(pre=True) # pre=True to ensure this runs before any other validation
def set_dynamic_defaults(cls, values):
# If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
cpu_count = multiprocessing.cpu_count()
if values.get('n_threads', 0) == -1:
values['n_threads'] = cpu_count
if values.get('n_threads_batch', 0) == -1:
values['n_threads_batch'] = cpu_count
return values
class ServerSettings(BaseSettings): class ServerSettings(BaseSettings):
"""Server settings used to configure the FastAPI and Uvicorn server.""" """Server settings used to configure the FastAPI and Uvicorn server."""
@ -195,6 +205,10 @@ class ServerSettings(BaseSettings):
default=True, default=True,
description="Whether to interrupt requests when a new request is received.", description="Whether to interrupt requests when a new request is received.",
) )
disable_ping_events: bool = Field(
default=False,
description="Disable EventSource pings (may be needed for some clients).",
)
class Settings(ServerSettings, ModelSettings): class Settings(ServerSettings, ModelSettings):

View file

@ -1,5 +1,5 @@
[build-system] [build-system]
requires = ["scikit-build-core[pyproject]>=0.5.1"] requires = ["scikit-build-core[pyproject]>=0.9.2"]
build-backend = "scikit_build_core.build" build-backend = "scikit_build_core.build"
[project] [project]
@ -35,6 +35,7 @@ server = [
"pydantic-settings>=2.0.1", "pydantic-settings>=2.0.1",
"sse-starlette>=1.6.1", "sse-starlette>=1.6.1",
"starlette-context>=0.3.6,<0.4", "starlette-context>=0.3.6,<0.4",
"PyYAML>=5.1",
] ]
test = [ test = [
"pytest>=7.4.0", "pytest>=7.4.0",

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 75cd4c77292034ecec587ecb401366f57338f7c0 Subproject commit 4e96a812b3ce7322a29a3008db2ed73d9087b176