baalajimaestro 2024-02-27 17:59:39 +05:30
commit f343259cf7
Signed by: baalajimaestro
GPG key ID: F93C394FE9BBAFD5
19 changed files with 460 additions and 266 deletions

View file

@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.2.52]
- feat: Update llama.cpp to ggerganov/llama.cpp@a33e6a0d2a66104ea9a906bdbf8a94d050189d91
- fix: Llava15ChatHandler (this function takes at least 4 arguments) by @abetlen in 8383a9e5620f5df5a88f62da16813eac200dd706
## [0.2.51]
- feat: Update llama.cpp to ggerganov/llama.cpp@c39373398803c669056304090050fe3f44b41bf9
- fix: Restore type hints for low-level api by @abetlen in 19234aa0dbd0c3c87656e65dd2b064665371925b
## [0.2.50]
- docs: Update Functionary OpenAI Server Readme by @jeffrey-fong in #1193
- fix: LlamaHFTokenizer now receives pre_tokens by @abetlen in 47bad30dd716443652275099fa3851811168ff4a
## [0.2.49]
- fix: module 'llama_cpp.llama_cpp' has no attribute 'c_uint8' in Llama.save_state by @abetlen in db776a885cd4c20811f22f8bd1a27ecc71dba927
- feat: Auto detect Mixtral's slightly different format by @lukestanley in #1214
## [0.2.48]
- feat: Update llama.cpp to ggerganov/llama.cpp@15499eb94227401bdc8875da6eb85c15d37068f7
@ -145,7 +165,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- feat: Update llama.cpp to ggerganov/llama.cpp@b3a7c20b5c035250257d2b62851c379b159c899a
- feat: Add `saiga` chat format by @femoiseev in #1050
- feat: Added `chatglm3` chat format by @xaviviro in #1059
- fix: Correct typo in README.md by @qeleb in (#1058)
## [0.2.26]
@ -278,7 +298,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.2.11]
- Fix bug in `llama_model_params` object has no attribute `logits_all` by @abetlen in d696251fbe40015e8616ea7a7d7ad5257fd1b896
## [0.2.10]
@ -466,7 +486,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.1.60]
NOTE: This release was deleted due to a bug with the packaging system that caused pip installations to fail.
- Truncate max_tokens in create_completion so requested tokens don't exceed the context size.
- Temporarily disable cache for completion requests
@ -490,4 +510,4 @@ NOTE: This release was deleted due to a bug with the packaging system that caus
- (misc) Added first version of the changelog
- (server) Use async routes
- (python-api) Use numpy for internal buffers to reduce memory usage and improve performance.
- (python-api) Performance bug in stop sequence check slowing down streaming.

View file

@ -365,14 +365,10 @@ To constrain the response further to a specific JSON Schema add the schema to th
### Function Calling
The high-level API also provides a simple interface for function calling. This is possible through the `functionary` pre-trained models chat format or through the generic `chatml-function-calling` chat format.
The gguf-converted files for functionary can be found here: [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF)
The high-level API supports OpenAI compatible function and tool calling. This is possible through the `functionary` pre-trained models chat format or through the generic `chatml-function-calling` chat format.
```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary")
>>> # or
>>> llm = Llama(model_path="path/to/chatml/llama-model.gguf", chat_format="chatml-function-calling")
>>> llm.create_chat_completion(
messages = [
@ -416,6 +412,25 @@ The gguf-converted files for functionary can be found here: [functionary-7b-v1](
)
```
<details>
<summary>Functionary v2</summary>
The various gguf-converted files for this set of models can be found [here](https://huggingface.co/meetkai). Functionary is able to intelligently call functions and also analyze any provided function outputs to generate coherent responses. All v2 models of functionary support **parallel function calling**. You can provide either `functionary-v1` or `functionary-v2` for the `chat_format` when initializing the Llama class.
Due to discrepancies between llama.cpp and HuggingFace's tokenizers, an HF tokenizer must be provided for functionary. The `LlamaHFTokenizer` class can be initialized and passed into the Llama class, overriding the default llama.cpp tokenizer. The tokenizer files are already included in the respective HF repositories hosting the gguf files.
```python
>>> from llama_cpp import Llama
>>> from llama_cpp.llama_tokenizer import LlamaHFTokenizer
>>> llm = Llama.from_pretrained(
repo_id="meetkai/functionary-small-v2.2-GGUF",
filename="functionary-small-v2.2.q4_0.gguf",
chat_format="functionary-v2",
tokenizer=LlamaHFTokenizer.from_pretrained("meetkai/functionary-small-v2.2-GGUF")
)
```
</details>
### Multi-modal Models
`llama-cpp-python` supports the llava1.5 family of multi-modal models which allow the language model to
@ -453,6 +468,38 @@ Then you'll need to use a custom chat handler to load the clip model and process
)
```
<details>
<summary>Loading a Local Image</summary>
Images can be passed as base64 encoded data URIs. The following example demonstrates how to do this.
```python
import base64
def image_to_base64_data_uri(file_path):
with open(file_path, "rb") as img_file:
base64_data = base64.b64encode(img_file.read()).decode('utf-8')
return f"data:image/png;base64,{base64_data}"
# Replace 'file_path.png' with the actual path to your PNG file
file_path = 'file_path.png'
data_uri = image_to_base64_data_uri(file_path)
messages = [
{"role": "system", "content": "You are an assistant who perfectly describes images."},
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_uri }},
{"type" : "text", "text": "Describe this image in detail please."}
]
}
]
```
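The assembled `messages` can then be passed to the model like any other chat request. A minimal follow-up sketch, assuming `llm` is a `Llama` instance configured with the llava chat handler described earlier in this section:
```python
# `llm` is assumed to be a Llama instance created with a Llava15ChatHandler,
# as shown above in this section.
response = llm.create_chat_completion(messages=messages)
print(response["choices"][0]["message"]["content"])
```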
</details>
### Speculative Decoding
`llama-cpp-python` supports speculative decoding which allows the model to generate completions based on a draft model.
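A minimal sketch of enabling it, assuming the `LlamaPromptLookupDecoding` draft model from `llama_cpp.llama_speculative` and a placeholder model path:
```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

# Prompt-lookup decoding drafts candidate tokens from the prompt itself;
# num_pred_tokens controls how many tokens are speculated per step.
llm = Llama(
    model_path="path/to/llama-model.gguf",  # placeholder path
    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10),
)
```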
@ -530,6 +577,12 @@ python3 -m llama_cpp.server --model models/7B/llama-model.gguf --chat_format cha
That will format the prompt according to how the model expects it. You can find the prompt format in the model card.
For possible options, see [llama_cpp/llama_chat_format.py](llama_cpp/llama_chat_format.py) and look for lines starting with "@register_chat_format".
If you have `huggingface-hub` installed, you can also use the `--hf_model_repo_id` flag to load a model from the Hugging Face Hub.
```bash
python3 -m llama_cpp.server --hf_model_repo_id Qwen/Qwen1.5-0.5B-Chat-GGUF --model '*q8_0.gguf'
```
### Web Server Features
- [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion)

View file

@ -76,12 +76,14 @@ Function calling is completely compatible with the OpenAI function calling API a
You'll first need to download one of the available function calling models in GGUF format:
- [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF)
- [functionary](https://huggingface.co/meetkai)
Then when you run the server you'll need to also specify the `functionary` chat_format
Then when you run the server you'll need to also specify either `functionary-v1` or `functionary-v2` as the `chat_format`.
Note that since functionary requires an HF tokenizer due to discrepancies between llama.cpp and HuggingFace's tokenizers, as mentioned [here](https://github.com/abetlen/llama-cpp-python/blob/main?tab=readme-ov-file#function-calling), you will need to pass in the path to the tokenizer too. The tokenizer files are already included in the respective HF repositories hosting the gguf files.
```bash
python3 -m llama_cpp.server --model <model_path> --chat_format functionary
python3 -m llama_cpp.server --model <model_path_to_functionary_v2_model> --chat_format functionary-v2 --hf_pretrained_model_name_or_path <model_path_to_functionary_v2_tokenizer>
```
Check out this [example notebook](https://github.com/abetlen/llama-cpp-python/blob/main/examples/notebooks/Functions.ipynb) for a walkthrough of some interesting use cases for function calling.

View file

@ -0,0 +1,59 @@
import llama_cpp
import llama_cpp.llama_tokenizer
import gradio as gr
llama = llama_cpp.Llama.from_pretrained(
repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
filename="*q8_0.gguf",
tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B"),
verbose=False
)
model = "gpt-3.5-turbo"
def predict(message, history):
messages = []
for user_message, assistant_message in history:
messages.append({"role": "user", "content": user_message})
messages.append({"role": "assistant", "content": assistant_message})
messages.append({"role": "user", "content": message})
response = llama.create_chat_completion_openai_v1(
model=model,
messages=messages,
stream=True
)
text = ""
for chunk in response:
content = chunk.choices[0].delta.content
if content:
text += content
yield text
js = """function () {
gradioURL = window.location.href
if (!gradioURL.endsWith('?__theme=dark')) {
window.location.replace(gradioURL + '?__theme=dark');
}
}"""
css = """
footer {
visibility: hidden;
}
full-height {
height: 100%;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), js=js, css=css, fill_height=True) as demo:
gr.ChatInterface(predict, fill_height=True, examples=["What is the capital of France?", "Who was the first person on the moon?"])
if __name__ == "__main__":
demo.launch()

View file

@ -0,0 +1,56 @@
import gradio as gr
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="llama.cpp"
)
model = "gpt-3.5-turbo"
def predict(message, history):
messages = []
for user_message, assistant_message in history:
messages.append({"role": "user", "content": user_message})
messages.append({"role": "assistant", "content": assistant_message})
messages.append({"role": "user", "content": message})
response = client.chat.completions.create(
model=model,
messages=messages,
stream=True
)
text = ""
for chunk in response:
content = chunk.choices[0].delta.content
if content:
text += content
yield text
js = """function () {
gradioURL = window.location.href
if (!gradioURL.endsWith('?__theme=dark')) {
window.location.replace(gradioURL + '?__theme=dark');
}
}"""
css = """
footer {
visibility: hidden;
}
full-height {
height: 100%;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), js=js, css=css, fill_height=True) as demo:
gr.ChatInterface(predict, fill_height=True, examples=["What is the capital of France?", "Who was the first person on the moon?"])
if __name__ == "__main__":
demo.launch()

39 examples/hf_pull/main.py Normal file
View file

@ -0,0 +1,39 @@
import llama_cpp
import llama_cpp.llama_tokenizer
llama = llama_cpp.Llama.from_pretrained(
repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
filename="*q8_0.gguf",
tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B"),
verbose=False
)
response = llama.create_chat_completion(
messages=[
{
"role": "user",
"content": "What is the capital of France?"
}
],
response_format={
"type": "json_object",
"schema": {
"type": "object",
"properties": {
"country": {"type": "string"},
"capital": {"type": "string"}
},
"required": ["country", "capital"],
}
},
stream=True
)
for chunk in response:
delta = chunk["choices"][0]["delta"]
if "content" not in delta:
continue
print(delta["content"], end="", flush=True)
print()

View file

@ -9,7 +9,7 @@
"The OpenAI compatbile web server in `llama-cpp-python` supports function calling.\n",
"\n",
"Function calling allows API clients to specify a schema that gives the model a format it should respond in.\n",
"Function calling in `llama-cpp-python` works by combining models pretrained for function calling such as [`functionary`](https://huggingface.co/abetlen/functionary-7b-v1-GGUF) with constrained sampling to produce a response that is compatible with the schema.\n",
"Function calling in `llama-cpp-python` works by combining models pretrained for function calling such as [`functionary`](https://huggingface.co/meetkai) with constrained sampling to produce a response that is compatible with the schema.\n",
"\n",
"Note however that this improves but does not guarantee that the response will be compatible with the schema.\n",
"\n",

View file

@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -25,7 +25,7 @@
"source": [
"from openai import OpenAI\n",
"\n",
"client = OpenAI(base_url=\"http://100.64.159.73:8000/v1\", api_key=\"sk-1234\")\n",
"client = OpenAI(base_url=\"http://localhost:8000/v1\", api_key=\"llama.cpp\")\n",
"response = client.chat.completions.create(\n",
" model=\"gpt-4-vision-preview\",\n",
" messages=[\n",
@ -42,7 +42,17 @@
" ],\n",
" }\n",
" ],\n",
" response_format={ \"type\": \"json_object\" }\n",
" response_format={ \n",
" \"type\": \"json_object\",\n",
" \"schema\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"text\": {\n",
" \"type\": \"string\"\n",
" }\n",
" }\n",
" }\n",
" }\n",
")\n",
"import json\n",
"print(json.loads(response.choices[0].message.content))"

View file

@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
__version__ = "0.2.48"
__version__ = "0.2.52"

View file

@ -291,7 +291,7 @@ class _LlamaContext:
def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
assert self.ctx is not None
llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift)
llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)
def get_state_size(self) -> int:
assert self.ctx is not None

View file

@ -5,8 +5,10 @@ import sys
import uuid
import time
import json
import ctypes
import fnmatch
import multiprocessing
from typing import (
List,
Optional,
@ -20,7 +22,6 @@ from typing import (
from collections import deque
from pathlib import Path
import ctypes
from llama_cpp.llama_types import List
@ -64,7 +65,7 @@ class Llama:
*,
# Model Params
n_gpu_layers: int = 0,
split_mode: int = llama_cpp.LLAMA_SPLIT_LAYER,
split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
main_gpu: int = 0,
tensor_split: Optional[List[float]] = None,
vocab_only: bool = False,
@ -77,7 +78,7 @@ class Llama:
n_batch: int = 512,
n_threads: Optional[int] = None,
n_threads_batch: Optional[int] = None,
rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED,
rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
rope_freq_base: float = 0.0,
rope_freq_scale: float = 0.0,
yarn_ext_factor: float = -1.0,
@ -237,13 +238,13 @@ class Llama:
for i, (k, v) in enumerate(kv_overrides.items()):
self._kv_overrides_array[i].key = k.encode("utf-8")
if isinstance(v, bool):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
self._kv_overrides_array[i].value.bool_value = v
elif isinstance(v, int):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
self._kv_overrides_array[i].value.int_value = v
elif isinstance(v, float):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
self._kv_overrides_array[i].value.float_value = v
else:
raise ValueError(f"Unknown value type for {k}: {v}")
@ -269,7 +270,7 @@ class Llama:
self.context_params.rope_scaling_type = (
rope_scaling_type
if rope_scaling_type is not None
else llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
)
self.context_params.rope_freq_base = (
rope_freq_base if rope_freq_base != 0.0 else 0
@ -479,7 +480,7 @@ class Llama:
Returns:
The detokenized string.
"""
return self.tokenizer_.detokenize(tokens, prev_tokens)
return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)
def set_cache(self, cache: Optional[BaseLlamaCache]):
"""Set the cache.
@ -1015,13 +1016,13 @@ class Llama:
grammar=grammar,
):
if token == self._token_eos:
text = self.detokenize(completion_tokens)
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "stop"
break
completion_tokens.append(token)
all_text = self.detokenize(completion_tokens)
all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
# Contains multi-byte UTF8
for k, char in enumerate(all_text[-3:]):
@ -1045,7 +1046,7 @@ class Llama:
if stream:
remaining_tokens = completion_tokens[returned_tokens:]
remaining_text = self.detokenize(remaining_tokens)
remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
remaining_length = len(remaining_text)
# We want to avoid yielding any characters from
@ -1067,17 +1068,17 @@ class Llama:
for token in remaining_tokens:
if token == self.token_bos():
continue
token_end_position += len(self.detokenize([token]))
token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
# Check if stop sequence is in the token
if token_end_position > (
remaining_length - first_stop_position
):
break
token_str = self.detokenize([token]).decode(
token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens]).decode(
self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore"
)
)
@ -1099,7 +1100,7 @@ class Llama:
top_logprob.update({token_str: current_logprobs[int(token)]})
logprobs_or_none = {
"tokens": [
self.detokenize([token]).decode(
self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore"
)
],
@ -1115,7 +1116,7 @@ class Llama:
"model": model_name,
"choices": [
{
"text": self.detokenize([token]).decode(
"text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
"utf-8", errors="ignore"
),
"index": 0,
@ -1129,7 +1130,7 @@ class Llama:
decode_success = False
for i in range(1, len(remaining_tokens) + 1):
try:
bs = self.detokenize(remaining_tokens[:i])
bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
ts = bs.decode("utf-8")
decode_success = True
break
@ -1164,14 +1165,14 @@ class Llama:
}
if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens)
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "length"
break
if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :]
):
text = self.detokenize(completion_tokens)
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
finish_reason = "stop"
if self.verbose:
@ -1179,7 +1180,7 @@ class Llama:
if stream:
remaining_tokens = completion_tokens[returned_tokens:]
all_text = self.detokenize(remaining_tokens)
all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
any_stop = [s for s in stop_sequences if s in all_text]
if len(any_stop) > 0:
end = min(all_text.index(stop) for stop in any_stop)
@ -1188,7 +1189,7 @@ class Llama:
token_end_position = 0
for token in remaining_tokens:
token_end_position += len(self.detokenize([token]))
token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
@ -1198,7 +1199,7 @@ class Llama:
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens])
self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens - 1
logits = self._scores[token_offset, :]
@ -1312,8 +1313,8 @@ class Llama:
all_tokens = completion_tokens
all_token_strs = [
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore")
for i, token in enumerate(all_tokens)
]
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
# TODO: may be able to change this loop to use np.take_along_dim
@ -1338,7 +1339,7 @@ class Llama:
)
token_logprobs.append(logprobs_token[int(token)])
top_logprob: Optional[Dict[str, float]] = {
self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: logprobs_token[int(token)]})
@ -1593,6 +1594,8 @@ class Llama:
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
) -> Union[
CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
@ -1789,7 +1792,7 @@ class Llama:
state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
if self.verbose:
print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr)
llama_state = (llama_cpp.c_uint8 * int(state_size))()
llama_state = (ctypes.c_uint8 * int(state_size))()
if self.verbose:
print("Llama.save_state: allocated state", file=sys.stderr)
n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state)
@ -1797,7 +1800,7 @@ class Llama:
print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr)
if int(n_bytes) > int(state_size):
raise RuntimeError("Failed to copy llama state data")
llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
llama_state_compact = (ctypes.c_uint8 * int(n_bytes))()
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
if self.verbose:
print(

View file

@ -29,6 +29,8 @@ MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{%
MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
### Chat Completion Handler ###
@ -470,7 +472,8 @@ def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[s
if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
return "chatml"
if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
if (metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE or
metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE):
return "mistral-instruct"
return None
@ -1954,10 +1957,10 @@ class Llava15ChatHandler:
with suppress_stdout_stderr(disable=self.verbose):
embed = (
self._llava_cpp.llava_image_embed_make_with_bytes(
ctx_clip=self.clip_ctx,
n_threads=llama.context_params.n_threads,
image_bytes=c_ubyte_ptr,
image_bytes_length=len(image_bytes),
self.clip_ctx,
llama.context_params.n_threads,
c_ubyte_ptr,
len(image_bytes),
)
)
try:
@ -1965,10 +1968,10 @@ class Llava15ChatHandler:
n_past_p = ctypes.pointer(n_past)
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.llava_eval_image_embed(
ctx_llama=llama.ctx,
embed=embed,
n_batch=llama.n_batch,
n_past=n_past_p,
llama.ctx,
embed,
llama.n_batch,
n_past_p,
)
assert llama.n_ctx() >= n_past.value
llama.n_tokens = n_past.value

View file

@ -109,12 +109,13 @@ if TYPE_CHECKING:
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
F = TypeVar("F", bound=Callable[..., Any])
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
def ctypes_function(
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
):
def decorator(f: Callable[..., Any]):
def decorator(f: F) -> F:
if enabled:
func = getattr(lib, name)
func.argtypes = argtypes
@ -199,6 +200,20 @@ LLAMA_VOCAB_TYPE_BPE = 1
LLAMA_VOCAB_TYPE_WPM = 2
# // note: these values should be synchronized with ggml_rope
# // TODO: maybe move this enum to ggml.h (ggml_rope_type)
# enum llama_rope_type {
# LLAMA_ROPE_TYPE_NONE = -1,
# LLAMA_ROPE_TYPE_NORM = 0,
# LLAMA_ROPE_TYPE_NEOX = 2,
# LLAMA_ROPE_TYPE_GLM = 4,
# };
LLAMA_ROPE_TYPE_NONE = -1
LLAMA_ROPE_TYPE_NORM = 0
LLAMA_ROPE_TYPE_NEOX = 2
LLAMA_ROPE_TYPE_GLM = 4
# enum llama_token_type {
# LLAMA_TOKEN_TYPE_UNDEFINED = 0,
# LLAMA_TOKEN_TYPE_NORMAL = 1,
@ -241,10 +256,14 @@ LLAMA_TOKEN_TYPE_BYTE = 6
# LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
# LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
# };
@ -268,42 +287,46 @@ LLAMA_FTYPE_MOSTLY_Q6_K = 18
LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19
LLAMA_FTYPE_MOSTLY_IQ2_XS = 20
LLAMA_FTYPE_MOSTLY_Q2_K_S = 21
LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22
LLAMA_FTYPE_MOSTLY_IQ3_XS = 22
LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23
LLAMA_FTYPE_MOSTLY_IQ1_S = 24
LLAMA_FTYPE_MOSTLY_IQ4_NL = 25
LLAMA_FTYPE_MOSTLY_IQ3_S = 26
LLAMA_FTYPE_MOSTLY_IQ3_M = 27
LLAMA_FTYPE_MOSTLY_IQ2_S = 28
LLAMA_FTYPE_MOSTLY_IQ2_M = 29
LLAMA_FTYPE_GUESSED = 1024
# enum llama_rope_scaling_type {
# LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
# LLAMA_ROPE_SCALING_NONE = 0,
# LLAMA_ROPE_SCALING_LINEAR = 1,
# LLAMA_ROPE_SCALING_YARN = 2,
# LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
# LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1,
# LLAMA_ROPE_SCALING_TYPE_NONE = 0,
# LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
# LLAMA_ROPE_SCALING_TYPE_YARN = 2,
# LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
# };
LLAMA_ROPE_SCALING_UNSPECIFIED = -1
LLAMA_ROPE_SCALING_NONE = 0
LLAMA_ROPE_SCALING_LINEAR = 1
LLAMA_ROPE_SCALING_YARN = 2
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN
LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1
LLAMA_ROPE_SCALING_TYPE_NONE = 0
LLAMA_ROPE_SCALING_TYPE_LINEAR = 1
LLAMA_ROPE_SCALING_TYPE_YARN = 2
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN
# enum llama_pooling_type {
# LLAMA_POOLING_NONE = 0,
# LLAMA_POOLING_MEAN = 1,
# LLAMA_POOLING_CLS = 2,
# LLAMA_POOLING_TYPE_NONE = 0,
# LLAMA_POOLING_TYPE_MEAN = 1,
# LLAMA_POOLING_TYPE_CLS = 2,
# };
LLAMA_POOLING_NONE = 0
LLAMA_POOLING_MEAN = 1
LLAMA_POOLING_CLS = 2
LLAMA_POOLING_TYPE_NONE = 0
LLAMA_POOLING_TYPE_MEAN = 1
LLAMA_POOLING_TYPE_CLS = 2
# enum llama_split_mode {
# LLAMA_SPLIT_NONE = 0, // single GPU
# LLAMA_SPLIT_LAYER = 1, // split layers and KV across GPUs
# LLAMA_SPLIT_ROW = 2, // split rows across GPUs
# LLAMA_SPLIT_MODE_NONE = 0, // single GPU
# LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
# LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
# };
LLAMA_SPLIT_NONE = 0
LLAMA_SPLIT_LAYER = 1
LLAMA_SPLIT_ROW = 2
LLAMA_SPLIT_MODE_NONE = 0
LLAMA_SPLIT_MODE_LAYER = 1
LLAMA_SPLIT_MODE_ROW = 2
# typedef struct llama_token_data {
@ -416,13 +439,13 @@ class llama_batch(ctypes.Structure):
# enum llama_model_kv_override_type {
# LLAMA_KV_OVERRIDE_INT,
# LLAMA_KV_OVERRIDE_FLOAT,
# LLAMA_KV_OVERRIDE_BOOL,
# LLAMA_KV_OVERRIDE_TYPE_INT,
# LLAMA_KV_OVERRIDE_TYPE_FLOAT,
# LLAMA_KV_OVERRIDE_TYPE_BOOL,
# };
LLAMA_KV_OVERRIDE_INT = 0
LLAMA_KV_OVERRIDE_FLOAT = 1
LLAMA_KV_OVERRIDE_BOOL = 2
LLAMA_KV_OVERRIDE_TYPE_INT = 0
LLAMA_KV_OVERRIDE_TYPE_FLOAT = 1
LLAMA_KV_OVERRIDE_TYPE_BOOL = 2
# struct llama_model_kv_override {
@ -887,104 +910,84 @@ def llama_time_us() -> int:
# LLAMA_API size_t llama_max_devices(void);
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int:
...
# LLAMA_API bool llama_supports_mmap (void);
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
def llama_supports_mmap() -> bool:
...
# LLAMA_API bool llama_supports_mlock (void);
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
def llama_supports_mlock() -> bool:
...
# LLAMA_API bool llama_supports_gpu_offload(void);
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
def llama_supports_gpu_offload() -> bool:
...
# LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead");
@ctypes_function("llama_mmap_supported", [], ctypes.c_bool)
def llama_mmap_supported() -> bool:
...
# LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead");
@ctypes_function("llama_mlock_supported", [], ctypes.c_bool)
def llama_mlock_supported() -> bool:
...
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
...
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ctx(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_batch(ctx: llama_context_p, /) -> int:
...
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_vocab_type(model: llama_model_p, /) -> int:
...
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_rope_type(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_vocab(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_ctx_train(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_embd(model: llama_model_p, /) -> int:
...
@ -992,8 +995,6 @@ def llama_n_embd(model: llama_model_p, /) -> int:
# // Get the model's RoPE frequency scaling factor
# LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@ctypes_function("llama_rope_freq_scale_train", [llama_model_p_ctypes], ctypes.c_float)
def llama_rope_freq_scale_train(model: llama_model_p, /) -> float:
"""Get the model's RoPE frequency scaling factor"""
@ -1008,8 +1009,6 @@ def llama_rope_freq_scale_train(model: llama_model_p, /) -> float:
# // Get metadata value as a string by key name
# LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
@ctypes_function(
"llama_model_meta_val_str",
[
@ -1033,8 +1032,6 @@ def llama_model_meta_val_str(
# // Get the number of metadata key/value pairs
# LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
@ctypes_function("llama_model_meta_count", [llama_model_p_ctypes], ctypes.c_int32)
def llama_model_meta_count(model: llama_model_p, /) -> int:
"""Get the number of metadata key/value pairs"""
@ -1043,8 +1040,6 @@ def llama_model_meta_count(model: llama_model_p, /) -> int:
# // Get metadata key name by index
# LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
@ctypes_function(
"llama_model_meta_key_by_index",
[
@ -1068,8 +1063,6 @@ def llama_model_meta_key_by_index(
# // Get metadata value as a string by index
# LLAMA_API int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
@ctypes_function(
"llama_model_meta_val_str_by_index",
[
@ -1093,8 +1086,6 @@ def llama_model_meta_val_str_by_index(
# // Get a string describing the model type
# LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
@ctypes_function(
"llama_model_desc",
[llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_size_t],
@ -1112,8 +1103,6 @@ def llama_model_desc(
# // Returns the total size of all the tensors in the model in bytes
# LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
@ctypes_function("llama_model_size", [llama_model_p_ctypes], ctypes.c_uint64)
def llama_model_size(model: llama_model_p, /) -> int:
"""Returns the total size of all the tensors in the model in bytes"""
@ -1122,8 +1111,6 @@ def llama_model_size(model: llama_model_p, /) -> int:
# // Returns the total number of parameters in the model
# LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
@ctypes_function("llama_model_n_params", [llama_model_p_ctypes], ctypes.c_uint64)
def llama_model_n_params(model: llama_model_p, /) -> int:
"""Returns the total number of parameters in the model"""
@ -1132,8 +1119,6 @@ def llama_model_n_params(model: llama_model_p, /) -> int:
# // Get a llama model tensor
# LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
@ctypes_function(
"llama_get_model_tensor", [llama_model_p_ctypes, ctypes.c_char_p], ctypes.c_void_p
)
@ -1149,8 +1134,6 @@ def llama_get_model_tensor(
# const char * fname_inp,
# const char * fname_out,
# const llama_model_quantize_params * params);
@ctypes_function(
"llama_model_quantize",
[
@ -1183,8 +1166,6 @@ def llama_model_quantize(
# const char * path_base_model,
# int32_t n_threads),
# "use llama_model_apply_lora_from_file instead");
@ctypes_function(
"llama_apply_lora_from_file",
[
@ -1219,8 +1200,6 @@ def llama_apply_lora_from_file(
# float scale,
# const char * path_base_model,
# int32_t n_threads);
@ctypes_function(
"llama_model_apply_lora_from_file",
[
@ -1308,8 +1287,6 @@ llama_kv_cache_view_p = ctypes.POINTER(llama_kv_cache_view)
# // Create an empty KV cache view. (use only for debugging purposes)
# LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
@ctypes_function(
"llama_kv_cache_view_init",
[llama_context_p_ctypes, ctypes.c_int32],
@ -1324,8 +1301,6 @@ def llama_kv_cache_view_init(
# // Free a KV cache view. (use only for debugging purposes)
# LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
@ctypes_function("llama_kv_cache_view_free", [llama_kv_cache_view_p], None)
def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore
"""Free a KV cache view. (use only for debugging purposes)"""
@ -1334,8 +1309,6 @@ def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]", /): #
# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
@ctypes_function(
"llama_kv_cache_view_update", [llama_context_p_ctypes, llama_kv_cache_view_p], None
)
@ -1347,8 +1320,6 @@ def llama_kv_cache_view_update(ctx: llama_context_p, view: CtypesPointerOrRef[ll
# // Returns the number of tokens in the KV cache (slow, use only for debug)
# // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
# LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@ctypes_function(
"llama_get_kv_cache_token_count", [llama_context_p_ctypes], ctypes.c_int32
)
@ -1361,8 +1332,6 @@ def llama_get_kv_cache_token_count(ctx: llama_context_p, /) -> int:
# // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
# LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
@ctypes_function(
"llama_get_kv_cache_used_cells", [llama_context_p_ctypes], ctypes.c_int32
)
@ -1374,8 +1343,6 @@ def llama_get_kv_cache_used_cells(ctx: llama_context_p, /) -> int:
# // Clear the KV cache
# LLAMA_API void llama_kv_cache_clear(
# struct llama_context * ctx);
@ctypes_function("llama_kv_cache_clear", [llama_context_p_ctypes], None)
def llama_kv_cache_clear(ctx: llama_context_p, /):
"""Clear the KV cache"""
@ -1391,8 +1358,6 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
# llama_seq_id seq_id,
# llama_pos p0,
# llama_pos p1);
@ctypes_function(
"llama_kv_cache_seq_rm",
[
@ -1427,8 +1392,6 @@ def llama_kv_cache_seq_rm(
# llama_seq_id seq_id_dst,
# llama_pos p0,
# llama_pos p1);
@ctypes_function(
"llama_kv_cache_seq_cp",
[
@ -1459,8 +1422,6 @@ def llama_kv_cache_seq_cp(
# LLAMA_API void llama_kv_cache_seq_keep(
# struct llama_context * ctx,
# llama_seq_id seq_id);
@ctypes_function(
"llama_kv_cache_seq_keep", [llama_context_p_ctypes, llama_seq_id], None
)
@ -1470,19 +1431,19 @@ def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, in
# // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
# // If the KV cache is RoPEd, the KV data is updated accordingly
# // If the KV cache is RoPEd, the KV data is updated accordingly:
# // - lazily on next llama_decode()
# // - explicitly with llama_kv_cache_update()
# // p0 < 0 : [0, p1]
# // p1 < 0 : [p0, inf)
# LLAMA_API void llama_kv_cache_seq_shift(
# LLAMA_API void llama_kv_cache_seq_add(
# struct llama_context * ctx,
# llama_seq_id seq_id,
# llama_pos p0,
# llama_pos p1,
# llama_pos delta);
@ctypes_function(
"llama_kv_cache_seq_shift",
"llama_kv_cache_seq_add",
[
llama_context_p_ctypes,
llama_seq_id,
@ -1492,7 +1453,7 @@ def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, in
],
None,
)
def llama_kv_cache_seq_shift(
def llama_kv_cache_seq_add(
ctx: llama_context_p,
seq_id: Union[llama_seq_id, int],
p0: Union[llama_pos, int],
@ -1501,7 +1462,9 @@ def llama_kv_cache_seq_shift(
/,
):
"""Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
If the KV cache is RoPEd, the KV data is updated accordingly
If the KV cache is RoPEd, the KV data is updated accordingly:
- lazily on next llama_decode()
- explicitly with llama_kv_cache_update()
p0 < 0 : [0, p1]
p1 < 0 : [p0, inf)"""
...
@ -1517,8 +1480,6 @@ def llama_kv_cache_seq_shift(
# llama_pos p0,
# llama_pos p1,
# int d);
@ctypes_function(
"llama_kv_cache_seq_div",
[
@ -1545,6 +1506,28 @@ def llama_kv_cache_seq_div(
...
# // Defragment the KV cache
# // This will be applied:
# // - lazily on next llama_decode()
# // - explicitly with llama_kv_cache_update()
# LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
@ctypes_function("llama_kv_cache_defrag", [llama_context_p_ctypes], None)
def llama_kv_cache_defrag(ctx: llama_context_p, /):
"""Defragment the KV cache
This will be applied:
- lazily on next llama_decode()
- explicitly with llama_kv_cache_update()"""
...
# // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
# LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
@ctypes_function("llama_kv_cache_update", [llama_context_p_ctypes], None)
def llama_kv_cache_update(ctx: llama_context_p, /):
"""Apply the KV cache updates (such as K-shifts, defragmentation, etc.)"""
...
# //
# // State / sessions
# //
@ -1553,8 +1536,6 @@ def llama_kv_cache_seq_div(
# Returns the maximum size in bytes of the state (rng, logits, embedding
# and kv_cache) - will often be smaller after compacting tokens
# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
@ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
def llama_get_state_size(ctx: llama_context_p, /) -> int:
"""Returns the maximum size in bytes of the state (rng, logits, embedding
@ -1568,8 +1549,6 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int:
# LLAMA_API size_t llama_copy_state_data(
# struct llama_context * ctx,
# uint8_t * dst);
@ctypes_function(
"llama_copy_state_data",
[
@ -1592,8 +1571,6 @@ def llama_copy_state_data(
# LLAMA_API size_t llama_set_state_data(
# struct llama_context * ctx,
# uint8_t * src);
@ctypes_function(
"llama_set_state_data",
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
@ -1613,8 +1590,6 @@ def llama_set_state_data(
# llama_token * tokens_out,
# size_t n_token_capacity,
# size_t * n_token_count_out);
@ctypes_function(
"llama_load_session_file",
[
@ -1642,8 +1617,6 @@ def llama_load_session_file(
# const char * path_session,
# const llama_token * tokens,
# size_t n_token_count);
@ctypes_function(
"llama_save_session_file",
[
@ -1680,8 +1653,6 @@ def llama_save_session_file(
# int32_t n_tokens,
# int32_t n_past),
# "use llama_decode() instead");
@ctypes_function(
"llama_eval",
[
@ -1715,8 +1686,6 @@ def llama_eval(
# int32_t n_tokens,
# int32_t n_past),
# "use llama_decode() instead");
@ctypes_function(
"llama_eval_embd",
[
@ -1748,8 +1717,6 @@ def llama_eval_embd(
# int32_t n_tokens,
# llama_pos pos_0,
# llama_seq_id seq_id);
@ctypes_function(
"llama_batch_get_one",
[
@ -1785,8 +1752,6 @@ def llama_batch_get_one(
# int32_t n_tokens,
# int32_t embd,
# int32_t n_seq_max);
@ctypes_function(
"llama_batch_init", [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32], llama_batch
)
@ -1808,8 +1773,6 @@ def llama_batch_init(
# // Frees a batch of tokens allocated with llama_batch_init()
# LLAMA_API void llama_batch_free(struct llama_batch batch);
@ctypes_function("llama_batch_free", [llama_batch], None)
def llama_batch_free(batch: llama_batch, /):
"""Frees a batch of tokens allocated with llama_batch_init()"""
@ -1823,8 +1786,6 @@ def llama_batch_free(batch: llama_batch, /):
# LLAMA_API int32_t llama_decode(
# struct llama_context * ctx,
# struct llama_batch batch);
@ctypes_function("llama_decode", [llama_context_p_ctypes, llama_batch], ctypes.c_int32)
def llama_decode(ctx: llama_context_p, batch: llama_batch, /) -> int:
"""Positive return values does not mean a fatal error, but rather a warning.
@ -1838,8 +1799,6 @@ def llama_decode(ctx: llama_context_p, batch: llama_batch, /) -> int:
# // n_threads is the number of threads used for generation (single token)
# // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
# LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
@ctypes_function(
"llama_set_n_threads",
[
@ -1868,8 +1827,6 @@ def llama_set_n_threads(
# // Rows: n_tokens provided with llama_batch
# // Cols: n_vocab
# LLAMA_API float * llama_get_logits(struct llama_context * ctx);
@ctypes_function(
"llama_get_logits", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
)
@ -1885,8 +1842,6 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
# // Logits for the ith token. Equivalent to:
# // llama_get_logits(ctx) + i*n_vocab
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
@ctypes_function(
"llama_get_logits_ith",
[llama_context_p_ctypes, ctypes.c_int32],
@ -1903,8 +1858,6 @@ def llama_get_logits_ith(
# Get the embeddings for the input
# shape: [n_embd] (1-dimensional)
# LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
@ctypes_function(
"llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
)
@ -1917,8 +1870,6 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
# // Get the embeddings for the ith sequence
# // llama_get_embeddings(ctx) + i*n_embd
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
@ctypes_function(
"llama_get_embeddings_ith",
[llama_context_p_ctypes, ctypes.c_int32],
@ -1938,8 +1889,6 @@ def llama_get_embeddings_ith(
# LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
@ctypes_function(
"llama_token_get_text", [llama_model_p_ctypes, llama_token], ctypes.c_char_p
)
@ -1950,8 +1899,6 @@ def llama_token_get_text(
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
@ctypes_function(
"llama_token_get_score", [llama_model_p_ctypes, llama_token], ctypes.c_float
)
@ -1962,8 +1909,6 @@ def llama_token_get_score(
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
@ctypes_function(
"llama_token_get_type", [llama_model_p_ctypes, llama_token], ctypes.c_int
)
@ -1977,8 +1922,6 @@ def llama_token_get_type(
# LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
@ctypes_function("llama_token_bos", [llama_model_p_ctypes], llama_token)
def llama_token_bos(model: llama_model_p, /) -> int:
"""beginning-of-sentence"""
@ -1986,8 +1929,6 @@ def llama_token_bos(model: llama_model_p, /) -> int:
# LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
@ctypes_function("llama_token_eos", [llama_model_p_ctypes], llama_token)
def llama_token_eos(model: llama_model_p, /) -> int:
"""end-of-sentence"""
@ -1995,8 +1936,6 @@ def llama_token_eos(model: llama_model_p, /) -> int:
# LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token)
def llama_token_nl(model: llama_model_p, /) -> int:
"""next-line"""
@ -2005,8 +1944,6 @@ def llama_token_nl(model: llama_model_p, /) -> int:
# // Returns -1 if unknown, 1 for true or 0 for false.
# LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
@ctypes_function("llama_add_bos_token", [llama_model_p_ctypes], ctypes.c_int32)
def llama_add_bos_token(model: llama_model_p, /) -> int:
"""Returns -1 if unknown, 1 for true or 0 for false."""
@ -2015,8 +1952,6 @@ def llama_add_bos_token(model: llama_model_p, /) -> int:
# // Returns -1 if unknown, 1 for true or 0 for false.
# LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
@ctypes_function("llama_add_eos_token", [llama_model_p_ctypes], ctypes.c_int32)
def llama_add_eos_token(model: llama_model_p, /) -> int:
"""Returns -1 if unknown, 1 for true or 0 for false."""
@ -2025,8 +1960,6 @@ def llama_add_eos_token(model: llama_model_p, /) -> int:
# // codellama infill tokens
# LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token)
def llama_token_prefix(model: llama_model_p) -> int:
"""codellama infill tokens"""
@ -2034,24 +1967,18 @@ def llama_token_prefix(model: llama_model_p) -> int:
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
def llama_token_middle(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
def llama_token_suffix(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
def llama_token_eot(model: llama_model_p, /) -> int:
...
@ -2076,8 +2003,6 @@ def llama_token_eot(model: llama_model_p, /) -> int:
# int32_t n_max_tokens,
# bool add_bos,
# bool special);
@ctypes_function(
"llama_tokenize",
[
@ -2114,8 +2039,6 @@ def llama_tokenize(
# llama_token token,
# char * buf,
# int32_t length);
@ctypes_function(
"llama_token_to_piece",
[
@ -2159,8 +2082,6 @@ def llama_token_to_piece(
# bool add_ass,
# char * buf,
# int32_t length);
@ctypes_function(
"llama_chat_apply_template",
[
@ -2190,8 +2111,6 @@ def llama_chat_apply_template(
# const llama_grammar_element ** rules,
# size_t n_rules,
# size_t start_rule_index);
@ctypes_function(
"llama_grammar_init",
[

View file

@ -16,12 +16,23 @@ class BaseLlamaTokenizer(abc.ABC):
def tokenize(
self, text: bytes, add_bos: bool = True, special: bool = True
) -> List[int]:
"""Tokenize the text into tokens.
Args:
text: The text to tokenize.
add_bos: Whether to add a beginning of sequence token.
special: Whether to tokenize text literally or as special tokens."""
raise NotImplementedError
@abc.abstractmethod
def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
"""Detokenize the tokens into text.
Args:
tokens: The tokens to detokenize.
prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens."""
raise NotImplementedError
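# Illustrative usage sketch (not part of this diff): driving a concrete
# implementation of this interface. `LlamaHFTokenizer` is defined later in
# this module; the repo id mirrors the README example and is illustrative only.
#
#   tok = LlamaHFTokenizer.from_pretrained("meetkai/functionary-small-v2.2-GGUF")
#   tokens = tok.tokenize(b"Hello, world!")
#   full_text = tok.detokenize(tokens)
#   # Detokenize only the continuation, letting the tokenizer account for the
#   # tokens that came before it:
#   tail = tok.detokenize(tokens[2:], prev_tokens=tokens[:2])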
@ -37,10 +48,7 @@ class LlamaTokenizer(BaseLlamaTokenizer):
def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
if prev_tokens is not None:
return self._model.detokenize(tokens[len(prev_tokens) :])
else:
return self._model.detokenize(tokens)
return self._model.detokenize(tokens)
def encode(
self, text: str, add_bos: bool = True, special: bool = True
@ -72,7 +80,7 @@ class LlamaHFTokenizer(BaseLlamaTokenizer):
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
if prev_tokens is not None:
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
text = self.hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
"utf-8", errors="ignore"
)

View file

@ -1,6 +1,7 @@
import sys
import os
import ctypes
import functools
from ctypes import (
c_bool,
c_char_p,
@ -13,7 +14,7 @@ from ctypes import (
Structure,
)
import pathlib
from typing import List, Union, NewType, Optional
from typing import List, Union, NewType, Optional, TypeVar, Callable, Any
import llama_cpp.llama_cpp as llama_cpp
@ -76,6 +77,31 @@ _libllava_base_name = "llava"
# Load the library
_libllava = _load_shared_library(_libllava_base_name)
# ctypes helper
F = TypeVar("F", bound=Callable[..., Any])
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
def ctypes_function(
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
):
def decorator(f: F) -> F:
if enabled:
func = getattr(lib, name)
func.argtypes = argtypes
func.restype = restype
functools.wraps(f)(func)
return func
else:
return f
return decorator
return ctypes_function
ctypes_function = ctypes_function_for_shared_library(_libllava)
################################################
# llava.h
@ -97,49 +123,35 @@ class llava_image_embed(Structure):
# /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
@ctypes_function("llava_validate_embed_size", [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes], c_bool)
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool:
...
llava_validate_embed_size = _libllava.llava_validate_embed_size
llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes]
llava_validate_embed_size.restype = c_bool
# /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
@ctypes_function("llava_image_embed_make_with_bytes", [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int], POINTER(llava_image_embed))
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]":
...
llava_image_embed_make_with_bytes = _libllava.llava_image_embed_make_with_bytes
llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int]
llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
# /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
@ctypes_function("llava_image_embed_make_with_filename", [clip_ctx_p_ctypes, c_int, c_char_p], POINTER(llava_image_embed))
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]":
...
llava_image_embed_make_with_filename = _libllava.llava_image_embed_make_with_filename
llava_image_embed_make_with_filename.argtypes = [clip_ctx_p_ctypes, c_int, c_char_p]
llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
...
llava_image_embed_free = _libllava.llava_image_embed_free
llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
llava_image_embed_free.restype = None
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
@ctypes_function("llava_eval_image_embed", [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)], c_bool)
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool:
...
llava_eval_image_embed = _libllava.llava_eval_image_embed
llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)]
llava_eval_image_embed.restype = c_bool
################################################
# clip.h
@ -148,18 +160,12 @@ llava_eval_image_embed.restype = c_bool
# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]:
...
clip_model_load = _libllava.clip_model_load
clip_model_load.argtypes = [c_char_p, c_int]
clip_model_load.restype = clip_ctx_p_ctypes
# /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx);
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
def clip_free(ctx: clip_ctx_p, /):
...
clip_free = _libllava.clip_free
clip_free.argtypes = [clip_ctx_p_ctypes]
clip_free.restype = None

View file

@ -120,9 +120,20 @@ class LlamaProxy:
kv_overrides[key] = float(value)
else:
raise ValueError(f"Unknown value type {value_type}")
import functools
_model = llama_cpp.Llama(
model_path=settings.model,
kwargs = {}
if settings.hf_model_repo_id is not None:
create_fn = functools.partial(llama_cpp.Llama.from_pretrained, repo_id=settings.hf_model_repo_id, filename=settings.model)
else:
create_fn = llama_cpp.Llama
kwargs["model_path"] = settings.model
_model = create_fn(
**kwargs,
# Model Params
n_gpu_layers=settings.n_gpu_layers,
main_gpu=settings.main_gpu,

View file

@ -29,7 +29,7 @@ class ModelSettings(BaseSettings):
description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
)
split_mode: int = Field(
default=llama_cpp.LLAMA_SPLIT_LAYER,
default=llama_cpp.LLAMA_SPLIT_MODE_LAYER,
description="The split mode to use.",
)
main_gpu: int = Field(
@ -74,7 +74,7 @@ class ModelSettings(BaseSettings):
ge=0,
description="The number of threads to use when batch processing.",
)
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED)
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED)
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor"
@ -143,6 +143,11 @@ class ModelSettings(BaseSettings):
default=None,
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
)
# Loading from HuggingFace Model Hub
hf_model_repo_id: Optional[str] = Field(
default=None,
description="The model repo id to use for the HuggingFace tokenizer model.",
)
# Speculative Decoding
draft_model: Optional[str] = Field(
default=None,

View file

@ -132,7 +132,7 @@ def mock_llama(monkeypatch):
assert ctx == llama._ctx.ctx, "context does not match mock_llama"
return
def mock_kv_cache_seq_shift(
def mock_kv_cache_seq_add(
ctx: llama_cpp.llama_context_p,
seq_id: llama_cpp.llama_seq_id,
pos0: llama_cpp.llama_pos,
@ -146,7 +146,7 @@ def mock_llama(monkeypatch):
monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_rm", mock_kv_cache_seq_rm)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_cp", mock_kv_cache_seq_cp)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_keep", mock_kv_cache_seq_keep)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_shift", mock_kv_cache_seq_shift)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_add", mock_kv_cache_seq_add)
return setup_mock

2 vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 15499eb94227401bdc8875da6eb85c15d37068f7
Subproject commit a33e6a0d2a66104ea9a906bdbf8a94d050189d91