Compare commits

...

7 commits

Author SHA1 Message Date
baalajimaestro eebae1a368
Merge https://github.com/abetlen/llama-cpp-python 2024-02-29 21:11:35 +05:30
Andrei Betlen 8c71725d53 fix: Remove deprecated cfg sampling functions 2024-02-28 14:37:07 -05:00
Andrei Betlen 727d60c28a misc: Format 2024-02-28 14:27:40 -05:00
Andrei Betlen 0d37ce52b1 feat: Update llama.cpp 2024-02-28 14:27:16 -05:00
Andrei Betlen ffcd4b2636 chore: Bump version 2024-02-28 01:38:32 -05:00
Sigbjørn Skjæret c36ab15e68
fix: eos/bos_token set correctly for Jinja2ChatFormatter and automatic chat formatter (#1230)
The token strings were not correctly retrieved (empty).
2024-02-28 01:30:31 -05:00
Andrei Betlen fea33c9b94 feat: Update llama.cpp 2024-02-27 12:22:17 -05:00
11 changed files with 63 additions and 237 deletions

View file

@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.2.53]
- feat: Update llama.cpp to ggerganov/llama.cpp@cb49e0f8c906e5da49e9f6d64a57742a9a241c6a
- fix: eos/bos_token set correctly for Jinja2ChatFormatter and automatic chat formatter by @CISC in #1230
## [0.2.52]
- feat: Update llama.cpp to ggerganov/llama.cpp@a33e6a0d2a66104ea9a906bdbf8a94d050189d91

View file

@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
__version__ = "0.2.52"
__version__ = "0.2.53"

View file

@ -357,21 +357,6 @@ class _LlamaContext:
penalty_present,
)
def sample_classifier_free_guidance(
self,
candidates: "_LlamaTokenDataArray",
guidance_ctx: "_LlamaContext",
scale: float,
):
assert self.ctx is not None
assert guidance_ctx.ctx is not None
llama_cpp.llama_sample_classifier_free_guidance(
self.ctx,
llama_cpp.byref(candidates.candidates),
guidance_ctx.ctx,
scale,
)
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
assert self.ctx is not None
llama_cpp.llama_sample_softmax(
@ -720,7 +705,7 @@ class _LlamaSamplingContext:
return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
def sample(
self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext] = None, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
self, ctx_main: _LlamaContext, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
):
n_vocab = ctx_main.model.n_vocab()
id: int = 0
@ -741,11 +726,6 @@ class _LlamaSamplingContext:
) # TODO: Only create this once
token_data_array.copy_logits(logits_array)
if ctx_cfg is not None:
ctx_main.sample_classifier_free_guidance(
token_data_array, ctx_cfg, self.params.cfg_scale
)
# apply penalties
if len(self.prev) > 0:
nl_token = ctx_main.model.token_nl()

View file

@ -408,8 +408,8 @@ class Llama:
except:
bos_token_id = self.token_bos()
eos_token = self.detokenize([eos_token_id]).decode("utf-8")
bos_token = self.detokenize([bos_token_id]).decode("utf-8")
eos_token = self._model.token_get_text(eos_token_id)
bos_token = self._model.token_get_text(bos_token_id)
if self.verbose:
print(f"Using chat template: {template}", file=sys.stderr)

View file

@ -111,6 +111,7 @@ if TYPE_CHECKING:
F = TypeVar("F", bound=Callable[..., Any])
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
def ctypes_function(
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
@ -264,6 +265,7 @@ LLAMA_TOKEN_TYPE_BYTE = 6
# LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
# LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
# };
@ -295,6 +297,7 @@ LLAMA_FTYPE_MOSTLY_IQ3_S = 26
LLAMA_FTYPE_MOSTLY_IQ3_M = 27
LLAMA_FTYPE_MOSTLY_IQ2_S = 28
LLAMA_FTYPE_MOSTLY_IQ2_M = 29
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30
LLAMA_FTYPE_GUESSED = 1024
# enum llama_rope_scaling_type {
@ -548,6 +551,7 @@ class llama_model_params(ctypes.Structure):
# float yarn_beta_fast; // YaRN low correction dim
# float yarn_beta_slow; // YaRN high correction dim
# uint32_t yarn_orig_ctx; // YaRN original context size
# float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
# ggml_backend_sched_eval_callback cb_eval;
# void * cb_eval_user_data;
@ -580,6 +584,7 @@ class llama_context_params(ctypes.Structure):
yarn_beta_fast (float): YaRN low correction dim
yarn_beta_slow (float): YaRN high correction dim
yarn_orig_ctx (int): YaRN original context size
defrag_thold (float): defragment the KV cache if holes/size > thold, < 0 disabled (default)
cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval
cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval
type_k (int): data type for K cache
@ -605,6 +610,7 @@ class llama_context_params(ctypes.Structure):
("yarn_beta_fast", ctypes.c_float),
("yarn_beta_slow", ctypes.c_float),
("yarn_orig_ctx", ctypes.c_uint32),
("defrag_thold", ctypes.c_float),
("cb_eval", ggml_backend_sched_eval_callback),
("cb_eval_user_data", ctypes.c_void_p),
("type_k", ctypes.c_int),
@ -933,18 +939,6 @@ def llama_supports_gpu_offload() -> bool:
...
# LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead");
@ctypes_function("llama_mmap_supported", [], ctypes.c_bool)
def llama_mmap_supported() -> bool:
...
# LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead");
@ctypes_function("llama_mlock_supported", [], ctypes.c_bool)
def llama_mlock_supported() -> bool:
...
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
@ -1153,47 +1147,6 @@ def llama_model_quantize(
...
# // Apply a LoRA adapter to a loaded model
# // path_base_model is the path to a higher quality model to use as a base for
# // the layers modified by the adapter. Can be NULL to use the current loaded model.
# // The model needs to be reloaded before applying a new adapter, otherwise the adapter
# // will be applied on top of the previous one
# // Returns 0 on success
# LLAMA_API DEPRECATED(int32_t llama_apply_lora_from_file(
# struct llama_context * ctx,
# const char * path_lora,
# float scale,
# const char * path_base_model,
# int32_t n_threads),
# "use llama_model_apply_lora_from_file instead");
@ctypes_function(
"llama_apply_lora_from_file",
[
llama_context_p_ctypes,
ctypes.c_char_p,
ctypes.c_float,
ctypes.c_char_p,
ctypes.c_int32,
],
ctypes.c_int32,
)
def llama_apply_lora_from_file(
ctx: llama_context_p,
path_lora: Union[ctypes.c_char_p, bytes],
scale: Union[ctypes.c_float, float],
path_base_model: Union[ctypes.c_char_p, bytes],
n_threads: Union[ctypes.c_int32, int],
/,
) -> int:
"""Apply a LoRA adapter to a loaded model
path_base_model is the path to a higher quality model to use as a base for
the layers modified by the adapter. Can be NULL to use the current loaded model.
The model needs to be reloaded before applying a new adapter, otherwise the adapter
will be applied on top of the previous one
Returns 0 on success"""
...
# LLAMA_API int32_t llama_model_apply_lora_from_file(
# const struct llama_model * model,
# const char * path_lora,
@ -1215,7 +1168,7 @@ def llama_model_apply_lora_from_file(
model: llama_model_p,
path_lora: Union[ctypes.c_char_p, bytes],
scale: Union[ctypes.c_float, float],
path_base_model: Union[ctypes.c_char_p, bytes],
path_base_model: Union[ctypes.c_char_p, bytes, None],
n_threads: Union[ctypes.c_int32, int],
/,
) -> int:
@ -1642,72 +1595,6 @@ def llama_save_session_file(
# //
# // Run the llama inference to obtain the logits and probabilities for the next token(s).
# // tokens + n_tokens is the provided batch of new tokens to process
# // n_past is the number of tokens to use from previous eval calls
# // Returns 0 on success
# // DEPRECATED: use llama_decode() instead
# LLAMA_API DEPRECATED(int llama_eval(
# struct llama_context * ctx,
# llama_token * tokens,
# int32_t n_tokens,
# int32_t n_past),
# "use llama_decode() instead");
@ctypes_function(
"llama_eval",
[
llama_context_p_ctypes,
llama_token_p,
ctypes.c_int32,
ctypes.c_int32,
],
ctypes.c_int,
)
def llama_eval(
ctx: llama_context_p,
tokens: CtypesArray[llama_token],
n_tokens: Union[ctypes.c_int, int],
n_past: Union[ctypes.c_int, int],
/,
) -> int:
"""Run the llama inference to obtain the logits and probabilities for the next token(s).
tokens + n_tokens is the provided batch of new tokens to process
n_past is the number of tokens to use from previous eval calls
Returns 0 on success
DEPRECATED: use llama_decode() instead"""
...
# // Same as llama_eval, but use float matrix input directly.
# // DEPRECATED: use llama_decode() instead
# LLAMA_API DEPRECATED(int llama_eval_embd(
# struct llama_context * ctx,
# float * embd,
# int32_t n_tokens,
# int32_t n_past),
# "use llama_decode() instead");
@ctypes_function(
"llama_eval_embd",
[
llama_context_p_ctypes,
ctypes.POINTER(ctypes.c_float),
ctypes.c_int32,
ctypes.c_int32,
],
ctypes.c_int,
)
def llama_eval_embd(
ctx: llama_context_p,
embd: CtypesArray[ctypes.c_float],
n_tokens: Union[ctypes.c_int, int],
n_past: Union[ctypes.c_int, int],
/,
) -> int:
"""Same as llama_eval, but use float matrix input directly.
DEPRECATED: use llama_decode() instead"""
...
# // Return batch for single sequence of tokens starting at pos_0
# //
# // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
@ -2242,35 +2129,6 @@ def llama_sample_apply_guidance(
...
# LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
# struct llama_context * ctx,
# llama_token_data_array * candidates,
# struct llama_context * guidance_ctx,
# float scale),
# "use llama_sample_apply_guidance() instead");
@ctypes_function(
"llama_sample_classifier_free_guidance",
[
llama_context_p_ctypes,
llama_token_data_array_p,
llama_context_p_ctypes,
ctypes.c_float,
],
None,
)
def llama_sample_classifier_free_guidance(
ctx: llama_context_p,
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
guidance_ctx: llama_context_p,
scale: Union[ctypes.c_float, float],
/,
):
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
...
# /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
# LLAMA_API void llama_sample_softmax(
# struct llama_context * ctx,
@ -2469,28 +2327,6 @@ def llama_sample_temp(
...
# LLAMA_API DEPRECATED(void llama_sample_temperature(
# struct llama_context * ctx,
# llama_token_data_array * candidates,
# float temp),
# "use llama_sample_temp instead");
@ctypes_function(
"llama_sample_temperature",
[llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float],
None,
)
def llama_sample_temperature(
ctx: llama_context_p,
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
temp: Union[ctypes.c_float, float],
/,
):
"""use llama_sample_temp instead"""
...
# /// @details Apply constraints from grammar
# LLAMA_API void llama_sample_grammar(
# struct llama_context * ctx,

View file

@ -199,8 +199,8 @@ async def authenticate(
@router.post(
"/v1/completions",
summary="Completion",
dependencies=[Depends(authenticate)],
response_model= Union[
dependencies=[Depends(authenticate)],
response_model=Union[
llama_cpp.CreateCompletionResponse,
str,
],
@ -211,19 +211,19 @@ async def authenticate(
"application/json": {
"schema": {
"anyOf": [
{"$ref": "#/components/schemas/CreateCompletionResponse"}
{"$ref": "#/components/schemas/CreateCompletionResponse"}
],
"title": "Completion response, when stream=False",
}
},
"text/event-stream":{
"schema": {
"type": "string",
"title": "Server Side Streaming response, when stream=True. " +
"See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
"example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
"text/event-stream": {
"schema": {
"type": "string",
"title": "Server Side Streaming response, when stream=True. "
+ "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
"example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
}
}
},
},
}
},
@ -290,7 +290,7 @@ async def create_completion(
inner_send_chan=send_chan,
iterator=iterator(),
),
sep='\n',
sep="\n",
)
else:
return iterator_or_completion
@ -310,10 +310,10 @@ async def create_embedding(
@router.post(
"/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)],
response_model= Union[
llama_cpp.ChatCompletion, str
],
"/v1/chat/completions",
summary="Chat",
dependencies=[Depends(authenticate)],
response_model=Union[llama_cpp.ChatCompletion, str],
responses={
"200": {
"description": "Successful Response",
@ -321,19 +321,21 @@ async def create_embedding(
"application/json": {
"schema": {
"anyOf": [
{"$ref": "#/components/schemas/CreateChatCompletionResponse"}
{
"$ref": "#/components/schemas/CreateChatCompletionResponse"
}
],
"title": "Completion response, when stream=False",
}
},
"text/event-stream":{
"schema": {
"type": "string",
"title": "Server Side Streaming response, when stream=True" +
"See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
"example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
"text/event-stream": {
"schema": {
"type": "string",
"title": "Server Side Streaming response, when stream=True"
+ "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
"example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
}
}
},
},
}
},
@ -383,7 +385,7 @@ async def create_chat_completion(
inner_send_chan=send_chan,
iterator=iterator(),
),
sep='\n',
sep="\n",
)
else:
return iterator_or_completion

View file

@ -22,6 +22,7 @@ from llama_cpp.server.types import (
CreateChatCompletionRequest,
)
class ErrorResponse(TypedDict):
"""OpenAI style error response"""
@ -75,7 +76,7 @@ class ErrorResponseFormatters:
(completion_tokens or 0) + prompt_tokens,
prompt_tokens,
completion_tokens,
), # type: ignore
), # type: ignore
type="invalid_request_error",
param="messages",
code="context_length_exceeded",
@ -207,4 +208,3 @@ class RouteErrorHandler(APIRoute):
)
return custom_route_handler

View file

@ -88,15 +88,15 @@ class LlamaProxy:
assert (
settings.hf_tokenizer_config_path is not None
), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
chat_handler = (
llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
json.load(open(settings.hf_tokenizer_config_path))
)
chat_handler = llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
json.load(open(settings.hf_tokenizer_config_path))
)
tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
if settings.hf_pretrained_model_name_or_path is not None:
tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(settings.hf_pretrained_model_name_or_path)
tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
settings.hf_pretrained_model_name_or_path
)
draft_model = None
if settings.draft_model is not None:
@ -120,17 +120,20 @@ class LlamaProxy:
kv_overrides[key] = float(value)
else:
raise ValueError(f"Unknown value type {value_type}")
import functools
kwargs = {}
if settings.hf_model_repo_id is not None:
create_fn = functools.partial(llama_cpp.Llama.from_pretrained, repo_id=settings.hf_model_repo_id, filename=settings.model)
create_fn = functools.partial(
llama_cpp.Llama.from_pretrained,
repo_id=settings.hf_model_repo_id,
filename=settings.model,
)
else:
create_fn = llama_cpp.Llama
kwargs["model_path"] = settings.model
_model = create_fn(
**kwargs,

View file

@ -45,11 +45,11 @@ class ModelSettings(BaseSettings):
default=False, description="Whether to only return the vocabulary."
)
use_mmap: bool = Field(
default=llama_cpp.llama_mmap_supported(),
default=llama_cpp.llama_supports_mmap(),
description="Use mmap.",
)
use_mlock: bool = Field(
default=llama_cpp.llama_mlock_supported(),
default=llama_cpp.llama_supports_mlock(),
description="Use mlock.",
)
kv_overrides: Optional[List[str]] = Field(
@ -74,7 +74,9 @@ class ModelSettings(BaseSettings):
ge=0,
description="The number of threads to use when batch processing.",
)
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED)
rope_scaling_type: int = Field(
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
)
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor"
@ -193,6 +195,4 @@ class Settings(ServerSettings, ModelSettings):
class ConfigFileSettings(ServerSettings):
"""Configuration file format settings."""
models: List[ModelSettings] = Field(
default=[], description="Model configs"
)
models: List[ModelSettings] = Field(default=[], description="Model configs")

View file

@ -110,7 +110,7 @@ class CreateCompletionRequest(BaseModel):
default=None,
description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
)
max_tokens: Optional[int] = Field(
max_tokens: Optional[int] = Field(
default=16, ge=0, description="The maximum number of tokens to generate."
)
temperature: float = temperature_field

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit a33e6a0d2a66104ea9a906bdbf8a94d050189d91
Subproject commit 08c5ee87e4cceb603ecceac90734fcdade57311b