Compare commits

...

17 commits

Author SHA1 Message Date
baalajimaestro dc23d15918
Merge https://github.com/abetlen/llama-cpp-python 2024-03-09 15:30:07 +05:30
Andrei Betlen a7281994d8 chore: Bump version 2024-03-08 21:14:44 -05:00
Andrei Betlen 919fca9f2b Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main 2024-03-08 21:10:56 -05:00
Andrei Betlen d02a9cf16f Fixed json strings grammar by blacklisting character control set. Closes #1259 2024-03-08 21:10:53 -05:00
Felipe Lorenz c139f8b5d5
feat: Add endpoints for tokenize, detokenize and count tokens (#1136)
* Add endpoint to count tokens

* Add tokenize and detokenize endpoints

* Change response key to tokens for tokenize endpoint

* Fix dependency bug

* Cleanup

* Remove example added by mistake

* Move tokenize, detokenize, and count to Extras namespace. Tag existing endpoints

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
2024-03-08 21:09:00 -05:00
Kevin Cao 1f3156d4f2
fix: Check for existence of clip model path (#1264) 2024-03-08 21:00:10 -05:00
Douglas Hanley 2811014bae
feat: Switch embed to llama_get_embeddings_seq (#1263)
* switch to llama_get_embeddings_seq

* Remove duplicate definition of llama_get_embeddings_seq

Co-authored-by: Andrei <abetlen@gmail.com>

---------

Co-authored-by: Andrei <abetlen@gmail.com>
2024-03-08 20:59:35 -05:00
Andrei Betlen 40c6b54f68 feat: Update llama.cpp 2024-03-08 20:58:50 -05:00
Andrei Betlen 93dc56ace8 Update llama.cpp 2024-03-06 01:32:00 -05:00
Andrei Betlen 87a6e5797e feat: Update llama.cpp 2024-03-03 11:27:04 -05:00
Andrei Betlen 13177aae0f chore: Bump version 2024-03-02 22:46:40 -05:00
Kenneth Hoste 663659f730
docs: fix small typo in README: 'model know how' -> 'model knows how' (#1244)
Co-authored-by: Andrei <abetlen@gmail.com>
2024-03-02 22:20:41 -05:00
Andrei Betlen 0e70984fb6 feat: Update llama.cpp 2024-03-02 22:20:04 -05:00
Andrei Betlen d5df431278 chore: Bump version 2024-03-01 13:15:16 -05:00
Andrei Betlen 97aa3a153d docs: Add information re: auto chat formats. Closes #1236 2024-03-01 13:10:25 -05:00
Andrei Betlen f062a7f51d feat: Update llama.cpp 2024-03-01 12:57:16 -05:00
Douglas Hanley cf1fdd8a9a
docs: fix typo in README.md embeddings example. (#1232) 2024-02-29 13:55:50 -05:00
10 changed files with 232 additions and 43 deletions

View file

@ -7,6 +7,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.2.56]
- feat: Update llama.cpp to ggerganov/llama.cpp@c2101a2e909ac7c08976d414e64e96c90ee5fa9e
- feat(server): Add endpoints for tokenize, detokenize and count tokens by @felipelo in #1136
- feat: Switch embed to llama_get_embeddings_seq by @iamlemec in #1263
- fix: Fixed json strings grammar by blacklisting character control set by @ExtReMLapin in d02a9cf16ff88ad011e2eb1ce29f4d9400f13cd1
- fix: Check for existence of clip model path by @kejcao in #1264
## [0.2.55]
- feat: Update llama.cpp to ggerganov/9731134296af3a6839cd682e51d9c2109a871de5
- docs: fix small typo in README: 'model know how' -> 'model knows how' by @boegel in #1244
## [0.2.54]
- feat: Update llama.cpp to ggerganov/llama.cpp@cb49e0f8c906e5da49e9f6d64a57742a9a241c6a
- docs: fix typo in README.md embeddings example by @iamlemec in #1232
## [0.2.53]
- feat: Update llama.cpp to ggerganov/llama.cpp@cb49e0f8c906e5da49e9f6d64a57742a9a241c6a

View file

@ -286,7 +286,16 @@ By default [`from_pretrained`](https://llama-cpp-python.readthedocs.io/en/latest
The high-level API also provides a simple interface for chat completion.
Note that `chat_format` option must be set for the particular model you are using.
Chat completion requires that the model knows how to format the messages into a single prompt.
The `Llama` class does this using pre-registered chat formats (ie. `chatml`, `llama-2`, `gemma`, etc) or by providing a custom chat handler object.
The model will will format the messages into a single prompt using the following order of precedence:
- Use the `chat_handler` if provided
- Use the `chat_format` if provided
- Use the `tokenizer.chat_template` from the `gguf` model's metadata (should work for most new models, older models may not have this)
- else, fallback to the `llama-2` chat format
Set `verbose=True` to see the selected chat format.
```python
>>> from llama_cpp import Llama
@ -525,7 +534,7 @@ To generate text embeddings use [`create_embedding`](http://localhost:8000/api-r
```python
import llama_cpp
llm = llama_cpp.Llama(model_path="path/to/model.gguf", embeddings=True)
llm = llama_cpp.Llama(model_path="path/to/model.gguf", embedding=True)
embeddings = llm.create_embedding("Hello, world!")

View file

@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
__version__ = "0.2.53"
__version__ = "0.2.56"

View file

@ -86,7 +86,6 @@ class Llama:
yarn_beta_fast: float = 32.0,
yarn_beta_slow: float = 1.0,
yarn_orig_ctx: int = 0,
mul_mat_q: bool = True,
logits_all: bool = False,
embedding: bool = False,
offload_kqv: bool = True,
@ -291,11 +290,10 @@ class Llama:
yarn_beta_slow if yarn_beta_slow != 0.0 else 0
)
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
self.context_params.mul_mat_q = mul_mat_q
self.context_params.logits_all = (
logits_all if draft_model is None else True
) # Must be set to True for speculative decoding
self.context_params.embedding = embedding
self.context_params.embeddings = embedding # TODO: Rename to embeddings
self.context_params.offload_kqv = offload_kqv
# Sampling Params
@ -412,7 +410,7 @@ class Llama:
bos_token = self._model.token_get_text(bos_token_id)
if self.verbose:
print(f"Using chat template: {template}", file=sys.stderr)
print(f"Using gguf chat template: {template}", file=sys.stderr)
print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
@ -422,6 +420,8 @@ class Llama:
if self.chat_format is None and self.chat_handler is None:
self.chat_format = "llama-2"
if self.verbose:
print(f"Using fallback chat format: {chat_format}", file=sys.stderr)
@property
def ctx(self) -> llama_cpp.llama_context_p:
@ -787,7 +787,7 @@ class Llama:
n_embd = self.n_embd()
n_batch = self.n_batch
if self.context_params.embedding == False:
if self.context_params.embeddings == False:
raise RuntimeError(
"Llama model must be created with embedding=True to call this method"
)
@ -814,7 +814,7 @@ class Llama:
# store embeddings
for i in range(n_seq):
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
self._ctx.ctx, i
)[:n_embd]
if normalize:
@ -1724,9 +1724,8 @@ class Llama:
yarn_beta_fast=self.context_params.yarn_beta_fast,
yarn_beta_slow=self.context_params.yarn_beta_slow,
yarn_orig_ctx=self.context_params.yarn_orig_ctx,
mul_mat_q=self.context_params.mul_mat_q,
logits_all=self.context_params.logits_all,
embedding=self.context_params.embedding,
embedding=self.context_params.embeddings,
# Sampling Params
last_n_tokens_size=self.last_n_tokens_size,
# LoRA Params
@ -1768,7 +1767,6 @@ class Llama:
yarn_beta_fast=state["yarn_beta_fast"],
yarn_beta_slow=state["yarn_beta_slow"],
yarn_orig_ctx=state["yarn_orig_ctx"],
mul_mat_q=state["mul_mat_q"],
logits_all=state["logits_all"],
embedding=state["embedding"],
# Sampling Params

View file

@ -1848,6 +1848,9 @@ class Llava15ChatHandler:
self.verbose = verbose
self._clip_free = self._llava_cpp._libllava.clip_free # type: ignore
if not os.path.exists(clip_model_path):
raise ValueError(f"Clip model path does not exist: {clip_model_path}")
with suppress_stdout_stderr(disable=self.verbose):
self.clip_ctx = self._llava_cpp.clip_model_load(
self.clip_model_path.encode(), 0

View file

@ -148,6 +148,12 @@ ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
)
# // Abort callback
# // If not NULL, called before ggml computation
# // If it returns true, the computation is aborted
# typedef bool (*ggml_abort_callback)(void * data);
ggml_abort_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p)
# llama.h bindings
_lib.llama_max_devices.argtypes = []
@ -314,10 +320,12 @@ LLAMA_ROPE_SCALING_TYPE_YARN = 2
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN
# enum llama_pooling_type {
# LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
# LLAMA_POOLING_TYPE_NONE = 0,
# LLAMA_POOLING_TYPE_MEAN = 1,
# LLAMA_POOLING_TYPE_CLS = 2,
# };
LLAMA_POOLING_TYPE_UNSPECIFIED = -1
LLAMA_POOLING_TYPE_NONE = 0
LLAMA_POOLING_TYPE_MEAN = 1
LLAMA_POOLING_TYPE_CLS = 2
@ -391,7 +399,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(
# // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
# // - pos : the positions of the respective token in the sequence
# // - seq_id : the sequence to which the respective token belongs
# // - logits : if zero, the logits for the respective token will not be output
# // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
# //
# typedef struct llama_batch {
# int32_t n_tokens;
@ -401,7 +409,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(
# llama_pos * pos;
# int32_t * n_seq_id;
# llama_seq_id ** seq_id;
# int8_t * logits;
# int8_t * logits; // TODO: rename this to "output"
# // NOTE: helpers for smooth API transition - can be deprecated in the future
@ -421,10 +429,12 @@ class llama_batch(ctypes.Structure):
The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
Attributes:
n_tokens (int): number of tokens
token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
embd (ctypes.Array[ctypes.ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
"""
_fields_ = [
@ -539,9 +549,13 @@ class llama_model_params(ctypes.Structure):
# uint32_t seed; // RNG seed, -1 for random
# uint32_t n_ctx; // text context, 0 = from model
# uint32_t n_batch; // prompt processing maximum batch size
# uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
# uint32_t n_threads; // number of threads to use for generation
# uint32_t n_threads_batch; // number of threads to use for batch processing
# int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
# enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
# enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
# // (ignored if no pooling layer)
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
# float rope_freq_base; // RoPE base frequency, 0 = from model
@ -559,13 +573,16 @@ class llama_model_params(ctypes.Structure):
# enum ggml_type type_k; // data type for K cache
# enum ggml_type type_v; // data type for V cache
# // Keep the booleans together to avoid misalignment during copy-by-value.
# bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
# bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
# bool embedding; // embedding mode only
# bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
# bool embeddings; // if true, extract embeddings (together with logits)
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
# bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
# // Abort callback
# // if it returns true, execution of llama_decode() will be aborted
# // currently works only with CPU execution
# ggml_abort_callback abort_callback;
# void * abort_callback_data;
# };
class llama_context_params(ctypes.Structure):
"""Parameters for llama_context
@ -574,9 +591,11 @@ class llama_context_params(ctypes.Structure):
seed (int): RNG seed, -1 for random
n_ctx (int): text context, 0 = from model
n_batch (int): prompt processing maximum batch size
n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models)
n_threads (int): number of threads to use for generation
n_threads_batch (int): number of threads to use for batch processing
rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
rope_freq_base (float): RoPE base frequency, 0 = from model
rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
@ -589,20 +608,22 @@ class llama_context_params(ctypes.Structure):
cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval
type_k (int): data type for K cache
type_v (int): data type for V cache
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
embedding (bool): embedding mode only
embeddings (bool): if true, extract embeddings (together with logits)
offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
do_pooling (bool): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
"""
_fields_ = [
("seed", ctypes.c_uint32),
("n_ctx", ctypes.c_uint32),
("n_batch", ctypes.c_uint32),
("n_parallel", ctypes.c_uint32),
("n_threads", ctypes.c_uint32),
("n_threads_batch", ctypes.c_uint32),
("rope_scaling_type", ctypes.c_int32),
("rope_scaling_type", ctypes.c_int),
("pooling_type", ctypes.c_int),
("rope_freq_base", ctypes.c_float),
("rope_freq_scale", ctypes.c_float),
("yarn_ext_factor", ctypes.c_float),
@ -615,11 +636,11 @@ class llama_context_params(ctypes.Structure):
("cb_eval_user_data", ctypes.c_void_p),
("type_k", ctypes.c_int),
("type_v", ctypes.c_int),
("mul_mat_q", ctypes.c_bool),
("logits_all", ctypes.c_bool),
("embedding", ctypes.c_bool),
("embeddings", ctypes.c_bool),
("offload_kqv", ctypes.c_bool),
("do_pooling", ctypes.c_bool),
("abort_callback", ggml_abort_callback),
("abort_callback_data", ctypes.c_void_p),
]
@ -1306,7 +1327,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
# // seq_id < 0 : match any sequence
# // p0 < 0 : [0, p1]
# // p1 < 0 : [p0, inf)
# LLAMA_API void llama_kv_cache_seq_rm(
# LLAMA_API bool llama_kv_cache_seq_rm(
# struct llama_context * ctx,
# llama_seq_id seq_id,
# llama_pos p0,
@ -1319,7 +1340,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
llama_pos,
llama_pos,
],
None,
ctypes.c_bool,
)
def llama_kv_cache_seq_rm(
ctx: llama_context_p,
@ -1327,7 +1348,7 @@ def llama_kv_cache_seq_rm(
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
/,
):
) -> bool:
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
seq_id < 0 : match any sequence
p0 < 0 : [0, p1]
@ -1519,11 +1540,11 @@ def llama_copy_state_data(
...
# Set the state reading from the specified address
# Returns the number of bytes read
# // Set the state reading from the specified address
# // Returns the number of bytes read
# LLAMA_API size_t llama_set_state_data(
# struct llama_context * ctx,
# uint8_t * src);
# const uint8_t * src);
@ctypes_function(
"llama_set_state_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
@ -1707,8 +1728,24 @@ def llama_set_n_threads(
"""
...
# // Set abort callback
# LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
@ctypes_function(
"llama_set_abort_callback",
[llama_context_p_ctypes, ggml_abort_callback, ctypes.c_void_p],
None,
)
def llama_set_abort_callback(
ctx: llama_context_p,
abort_callback: Callable[[ctypes.c_void_p], None],
abort_callback_data: ctypes.c_void_p,
/,
):
"""Set abort callback"""
...
# // Token logits obtained from the last call to llama_eval()
# // Token logits obtained from the last call to llama_decode()
# // The logits for the last token are stored in the last row
# // Logits for which llama_batch.logits[i] == 0 are undefined
# // Rows: n_tokens provided with llama_batch
@ -1722,7 +1759,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
The logits for the last token are stored in the last row
Logits for which llama_batch.logits[i] == 0 are undefined
Rows: n_tokens provided with llama_batch
Cols: n_vocab"""
Cols: n_vocab
Returns:
Pointer to the logits buffer of shape (n_tokens, n_vocab)"""
...
@ -1742,8 +1782,8 @@ def llama_get_logits_ith(
...
# Get the embeddings for the input
# shape: [n_embd] (1-dimensional)
# // Get all output token embeddings
# // shape: [n_tokens*n_embd] (1-dimensional)
# LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
@ctypes_function(
"llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
@ -1754,8 +1794,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
...
# // Get the embeddings for the ith sequence
# // Get the embeddings for the ith token
# // llama_get_embeddings(ctx) + i*n_embd
# // shape: [n_embd] (1-dimensional)
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
@ctypes_function(
"llama_get_embeddings_ith",
@ -1770,6 +1811,23 @@ def llama_get_embeddings_ith(
...
# // Get the embeddings for a sequence id
# // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
# // shape: [n_embd] (1-dimensional)
# LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
@ctypes_function(
"llama_get_embeddings_seq",
[llama_context_p_ctypes, llama_seq_id],
ctypes.POINTER(ctypes.c_float),
)
def llama_get_embeddings_seq(
ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /
) -> CtypesArray[ctypes.c_float]:
"""Get the embeddings for a sequence id
Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
shape: [n_embd] (1-dimensional)"""
...
# //
# // Vocab
# //

View file

@ -1337,7 +1337,7 @@ array ::=
string ::=
"\"" (
[^"\\] |
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
@ -1366,7 +1366,7 @@ array ::=
string ::=
"\"" (
[^"\\] |
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

View file

@ -41,6 +41,11 @@ from llama_cpp.server.types import (
CreateEmbeddingRequest,
CreateChatCompletionRequest,
ModelList,
TokenizeInputRequest,
TokenizeInputResponse,
TokenizeInputCountResponse,
DetokenizeInputRequest,
DetokenizeInputResponse,
)
from llama_cpp.server.errors import RouteErrorHandler
@ -196,6 +201,9 @@ async def authenticate(
)
openai_v1_tag = "OpenAI V1"
@router.post(
"/v1/completions",
summary="Completion",
@ -227,11 +235,13 @@ async def authenticate(
},
}
},
tags=[openai_v1_tag],
)
@router.post(
"/v1/engines/copilot-codex/completions",
include_in_schema=False,
dependencies=[Depends(authenticate)],
tags=[openai_v1_tag],
)
async def create_completion(
request: Request,
@ -297,7 +307,10 @@ async def create_completion(
@router.post(
"/v1/embeddings", summary="Embedding", dependencies=[Depends(authenticate)]
"/v1/embeddings",
summary="Embedding",
dependencies=[Depends(authenticate)],
tags=[openai_v1_tag],
)
async def create_embedding(
request: CreateEmbeddingRequest,
@ -339,6 +352,7 @@ async def create_embedding(
},
}
},
tags=[openai_v1_tag],
)
async def create_chat_completion(
request: Request,
@ -391,7 +405,12 @@ async def create_chat_completion(
return iterator_or_completion
@router.get("/v1/models", summary="Models", dependencies=[Depends(authenticate)])
@router.get(
"/v1/models",
summary="Models",
dependencies=[Depends(authenticate)],
tags=[openai_v1_tag],
)
async def get_models(
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
@ -407,3 +426,51 @@ async def get_models(
for model_alias in llama_proxy
],
}
extras_tag = "Extras"
@router.post(
"/extras/tokenize",
summary="Tokenize",
dependencies=[Depends(authenticate)],
tags=[extras_tag],
)
async def tokenize(
body: TokenizeInputRequest,
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputResponse:
tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
return {"tokens": tokens}
@router.post(
"/extras/tokenize/count",
summary="Tokenize Count",
dependencies=[Depends(authenticate)],
tags=[extras_tag],
)
async def count_query_tokens(
body: TokenizeInputRequest,
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputCountResponse:
tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
return {"count": len(tokens)}
@router.post(
"/extras/detokenize",
summary="Detokenize",
dependencies=[Depends(authenticate)],
tags=[extras_tag],
)
async def detokenize(
body: DetokenizeInputRequest,
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> DetokenizeInputResponse:
text = llama_proxy(body.model).detokenize(body.tokens).decode("utf-8")
return {"text": text}

View file

@ -264,3 +264,39 @@ class ModelData(TypedDict):
class ModelList(TypedDict):
object: Literal["list"]
data: List[ModelData]
class TokenizeInputRequest(BaseModel):
model: Optional[str] = model_field
input: Optional[str] = Field(description="The input to tokenize.")
model_config = {
"json_schema_extra": {"examples": [{"input": "How many tokens in this query?"}]}
}
class TokenizeInputResponse(BaseModel):
tokens: List[int] = Field(description="A list of tokens.")
model_config = {"json_schema_extra": {"example": {"tokens": [123, 321, 222]}}}
class TokenizeInputCountResponse(BaseModel):
count: int = Field(description="The number of tokens in the input.")
model_config = {"json_schema_extra": {"example": {"count": 5}}}
class DetokenizeInputRequest(BaseModel):
model: Optional[str] = model_field
tokens: List[int] = Field(description="A list of toekns to detokenize.")
model_config = {"json_schema_extra": {"example": [{"tokens": [123, 321, 222]}]}}
class DetokenizeInputResponse(BaseModel):
text: str = Field(description="The detokenized text.")
model_config = {
"json_schema_extra": {"example": {"text": "How many tokens in this query?"}}
}

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 08c5ee87e4cceb603ecceac90734fcdade57311b
Subproject commit c2101a2e909ac7c08976d414e64e96c90ee5fa9e