Merge branch main into custom_rope

This commit is contained in:
Andrei Betlen 2023-07-15 15:11:01 -04:00
parent 3f8f276f9f
commit f0797a6054
8 changed files with 211 additions and 68 deletions

View file

@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [0.1.71]
### Added
- (llama.cpp) Update llama.cpp
### Fixed
- (server) Fix several pydantic v2 migration bugs
## [0.1.70] ## [0.1.70]
### Fixed ### Fixed

View file

@ -135,6 +135,7 @@ A Docker image is available on [GHCR](https://ghcr.io/abetlen/llama-cpp-python).
```bash ```bash
docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggml-model-name.bin ghcr.io/abetlen/llama-cpp-python:latest docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggml-model-name.bin ghcr.io/abetlen/llama-cpp-python:latest
``` ```
[Docker on termux (requires root)](https://gist.github.com/FreddieOliveira/efe850df7ff3951cb62d74bd770dce27) is currently the only known way to run this on phones, see [termux support issue](https://github.com/abetlen/llama-cpp-python/issues/389)
## Low-level API ## Low-level API

View file

@ -19,6 +19,7 @@ from typing import (
from collections import deque, OrderedDict from collections import deque, OrderedDict
import diskcache import diskcache
import ctypes
from . import llama_cpp from . import llama_cpp
from .llama_types import * from .llama_types import *
@ -26,7 +27,6 @@ from .llama_types import *
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
class BaseLlamaCache(ABC): class BaseLlamaCache(ABC):
"""Base cache class for a llama.cpp model.""" """Base cache class for a llama.cpp model."""
@ -222,6 +222,7 @@ class Llama:
lora_base: Optional[str] = None, lora_base: Optional[str] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
low_vram: bool = False, low_vram: bool = False,
tensor_split: Optional[List[float]] = None,
verbose: bool = True, verbose: bool = True,
): ):
"""Load a llama.cpp model from `model_path`. """Load a llama.cpp model from `model_path`.
@ -244,6 +245,7 @@ class Llama:
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model. lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
lora_path: Path to a LoRA file to apply to the model. lora_path: Path to a LoRA file to apply to the model.
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
verbose: Print verbose output to stderr. verbose: Print verbose output to stderr.
Raises: Raises:
@ -252,6 +254,7 @@ class Llama:
Returns: Returns:
A Llama instance. A Llama instance.
""" """
self.verbose = verbose self.verbose = verbose
self.model_path = model_path self.model_path = model_path
@ -269,6 +272,15 @@ class Llama:
self.params.embedding = embedding self.params.embedding = embedding
self.params.low_vram = low_vram self.params.low_vram = low_vram
self.tensor_split = tensor_split
self._c_tensor_split = None
if self.tensor_split is not None:
#Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._c_tensor_split
self.last_n_tokens_size = last_n_tokens_size self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch) self.n_batch = min(n_ctx, n_batch)
@ -1509,6 +1521,7 @@ class Llama:
n_threads=self.n_threads, n_threads=self.n_threads,
lora_base=self.lora_base, lora_base=self.lora_base,
lora_path=self.lora_path, lora_path=self.lora_path,
tensor_split=self.tensor_split,
### DEPRECATED ### ### DEPRECATED ###
n_parts=self.n_parts, n_parts=self.n_parts,
### DEPRECATED ### ### DEPRECATED ###
@ -1533,6 +1546,7 @@ class Llama:
last_n_tokens_size=state["last_n_tokens_size"], last_n_tokens_size=state["last_n_tokens_size"],
lora_base=state["lora_base"], lora_base=state["lora_base"],
lora_path=state["lora_path"], lora_path=state["lora_path"],
tensor_split=state["tensor_split"],
verbose=state["verbose"], verbose=state["verbose"],
) )

View file

@ -165,12 +165,16 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
# int32_t n_gpu_layers; // number of layers to store in VRAM # int32_t n_gpu_layers; // number of layers to store in VRAM
# int32_t main_gpu; // the GPU that is used for scratch and small tensors # int32_t main_gpu; // the GPU that is used for scratch and small tensors
# float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs # float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
# float rope_freq_base; // RoPE base frequency
# float rope_freq_scale; // RoPE frequency scaling factor
# // called with a progress value between 0 and 1, pass NULL to disable # // called with a progress value between 0 and 1, pass NULL to disable
# llama_progress_callback progress_callback; # llama_progress_callback progress_callback;
# // context pointer passed to the progress callback # // context pointer passed to the progress callback
# void * progress_callback_user_data; # void * progress_callback_user_data;
# // Keep the booleans together to avoid misalignment during copy-by-value. # // Keep the booleans together to avoid misalignment during copy-by-value.
# bool low_vram; // if true, reduce VRAM usage at the cost of performance # bool low_vram; // if true, reduce VRAM usage at the cost of performance
# bool f16_kv; // use fp16 for KV cache # bool f16_kv; // use fp16 for KV cache
@ -190,6 +194,8 @@ class llama_context_params(Structure):
("n_gpu_layers", c_int32), ("n_gpu_layers", c_int32),
("main_gpu", c_int32), ("main_gpu", c_int32),
("tensor_split", c_float * LLAMA_MAX_DEVICES.value), ("tensor_split", c_float * LLAMA_MAX_DEVICES.value),
("rope_freq_base", c_float),
("rope_freq_scale", c_float),
("progress_callback", llama_progress_callback), ("progress_callback", llama_progress_callback),
("progress_callback_user_data", c_void_p), ("progress_callback_user_data", c_void_p),
("low_vram", c_bool), ("low_vram", c_bool),
@ -328,13 +334,23 @@ _lib.llama_mlock_supported.restype = c_bool
# // Initialize the llama + ggml backend # // Initialize the llama + ggml backend
# // If numa is true, use NUMA optimizations # // If numa is true, use NUMA optimizations
# // Call once at the start of the program # // Call once at the start of the program
# LLAMA_API void llama_init_backend(bool numa); # LLAMA_API void llama_backend_init(bool numa);
def llama_init_backend(numa: c_bool): def llama_backend_init(numa: c_bool):
return _lib.llama_init_backend(numa) return _lib.llama_backend_init(numa)
_lib.llama_init_backend.argtypes = [c_bool] _lib.llama_backend_init.argtypes = [c_bool]
_lib.llama_init_backend.restype = None _lib.llama_backend_init.restype = None
# // Call once at the end of the program - currently only used for MPI
# LLAMA_API void llama_backend_free();
def llama_backend_free():
return _lib.llama_backend_free()
_lib.llama_backend_free.argtypes = []
_lib.llama_backend_free.restype = None
# LLAMA_API struct llama_model * llama_load_model_from_file( # LLAMA_API struct llama_model * llama_load_model_from_file(
@ -648,6 +664,22 @@ _lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int,
_lib.llama_tokenize.restype = c_int _lib.llama_tokenize.restype = c_int
# LLAMA_API int llama_tokenize_with_model(
# const struct llama_model * model,
# const char * text,
# llama_token * tokens,
# int n_max_tokens,
# bool add_bos);
def llama_tokenize_with_model(
model: llama_model_p,
text: bytes,
tokens, # type: Array[llama_token]
n_max_tokens: c_int,
add_bos: c_bool,
) -> int:
return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos)
# LLAMA_API int llama_n_vocab(const struct llama_context * ctx); # LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
def llama_n_vocab(ctx: llama_context_p) -> int: def llama_n_vocab(ctx: llama_context_p) -> int:
return _lib.llama_n_vocab(ctx) return _lib.llama_n_vocab(ctx)
@ -675,6 +707,33 @@ _lib.llama_n_embd.argtypes = [llama_context_p]
_lib.llama_n_embd.restype = c_int _lib.llama_n_embd.restype = c_int
# LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model);
def llama_n_vocab_from_model(model: llama_model_p) -> int:
return _lib.llama_n_vocab_from_model(model)
_lib.llama_n_vocab_from_model.argtypes = [llama_model_p]
_lib.llama_n_vocab_from_model.restype = c_int
# LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model);
def llama_n_ctx_from_model(model: llama_model_p) -> int:
return _lib.llama_n_ctx_from_model(model)
_lib.llama_n_ctx_from_model.argtypes = [llama_model_p]
_lib.llama_n_ctx_from_model.restype = c_int
# LLAMA_API int llama_n_embd_from_model (const struct llama_model * model);
def llama_n_embd_from_model(model: llama_model_p) -> int:
return _lib.llama_n_embd_from_model(model)
_lib.llama_n_embd_from_model.argtypes = [llama_model_p]
_lib.llama_n_embd_from_model.restype = c_int
# // Get the vocabulary as output parameters. # // Get the vocabulary as output parameters.
# // Returns number of results. # // Returns number of results.
# LLAMA_API int llama_get_vocab( # LLAMA_API int llama_get_vocab(
@ -695,6 +754,20 @@ _lib.llama_get_vocab.argtypes = [llama_context_p, c_char_p, c_float, c_int]
_lib.llama_get_vocab.restype = c_int _lib.llama_get_vocab.restype = c_int
# LLAMA_API int llama_get_vocab_from_model(
# const struct llama_model * model,
# const char * * strings,
# float * scores,
# int capacity);
def llama_get_vocab_from_model(
model: llama_model_p,
strings, # type: Array[c_char_p] # type: ignore
scores, # type: Array[c_float] # type: ignore
capacity: c_int,
) -> int:
return _lib.llama_get_vocab_from_model(model, strings, scores, capacity)
# Token logits obtained from the last call to llama_eval() # Token logits obtained from the last call to llama_eval()
# The logits for the last token are stored in the last row # The logits for the last token are stored in the last row
# Can be mutated in order to change the probabilities of the next token # Can be mutated in order to change the probabilities of the next token
@ -724,8 +797,10 @@ _lib.llama_get_embeddings.argtypes = [llama_context_p]
_lib.llama_get_embeddings.restype = c_float_p _lib.llama_get_embeddings.restype = c_float_p
# Token Id -> String. Uses the vocabulary in the provided context # // Token Id -> String. Uses the vocabulary in the provided context
# LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token); # LLAMA_API const char * llama_token_to_str(
# const struct llama_context * ctx,
# llama_token token);
def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes: def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes:
return _lib.llama_token_to_str(ctx, token) return _lib.llama_token_to_str(ctx, token)
@ -733,6 +808,17 @@ def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes:
_lib.llama_token_to_str.argtypes = [llama_context_p, llama_token] _lib.llama_token_to_str.argtypes = [llama_context_p, llama_token]
_lib.llama_token_to_str.restype = c_char_p _lib.llama_token_to_str.restype = c_char_p
# LLAMA_API const char * llama_token_to_str_with_model(
# const struct llama_model * model,
# llama_token token);
def llama_token_to_str_with_model(model: llama_model_p, token: llama_token) -> bytes:
return _lib.llama_token_to_str_with_model(model, token)
_lib.llama_token_to_str_with_model.argtypes = [llama_model_p, llama_token]
_lib.llama_token_to_str_with_model.restype = c_char_p
# Special tokens # Special tokens
@ -821,6 +907,39 @@ _lib.llama_sample_frequency_and_presence_penalties.argtypes = [
_lib.llama_sample_frequency_and_presence_penalties.restype = None _lib.llama_sample_frequency_and_presence_penalties.restype = None
# /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
# /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
# /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
# /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
# /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
# LLAMA_API void llama_sample_classifier_free_guidance(
# struct llama_context * ctx,
# llama_token_data_array * candidates,
# struct llama_context * guidance_ctx,
# float scale,
# float smooth_factor);
def llama_sample_classifier_free_guidance(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
guidance_ctx: llama_context_p,
scale: c_float,
smooth_factor: c_float,
):
return _lib.llama_sample_classifier_free_guidance(
ctx, candidates, guidance_ctx, scale, smooth_factor
)
_lib.llama_sample_classifier_free_guidance.argtypes = [
llama_context_p,
llama_token_data_array_p,
llama_context_p,
c_float,
c_float,
]
_lib.llama_sample_classifier_free_guidance.restype = None
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. # @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); # LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
def llama_sample_softmax( def llama_sample_softmax(
@ -1065,5 +1184,5 @@ _lib.llama_print_system_info.restype = c_char_p
_llama_initialized = False _llama_initialized = False
if not _llama_initialized: if not _llama_initialized:
llama_init_backend(c_bool(False)) llama_backend_init(c_bool(False))
_llama_initialized = True _llama_initialized = True

View file

@ -31,6 +31,10 @@ class Settings(BaseSettings):
ge=0, ge=0,
description="The number of layers to put on the GPU. The rest will be on the CPU.", description="The number of layers to put on the GPU. The rest will be on the CPU.",
) )
tensor_split: Optional[List[float]] = Field(
default=None,
description="Split layers across multiple GPUs in proportion.",
)
seed: int = Field( seed: int = Field(
default=1337, description="Random seed. -1 for random." default=1337, description="Random seed. -1 for random."
) )
@ -80,12 +84,8 @@ class Settings(BaseSettings):
verbose: bool = Field( verbose: bool = Field(
default=True, description="Whether to print debug information." default=True, description="Whether to print debug information."
) )
host: str = Field( host: str = Field(default="localhost", description="Listen address")
default="localhost", description="Listen address" port: int = Field(default=8000, description="Listen port")
)
port: int = Field(
default=8000, description="Listen port"
)
interrupt_requests: bool = Field( interrupt_requests: bool = Field(
default=True, default=True,
description="Whether to interrupt requests when a new request is received.", description="Whether to interrupt requests when a new request is received.",
@ -117,6 +117,7 @@ def create_app(settings: Optional[Settings] = None):
llama = llama_cpp.Llama( llama = llama_cpp.Llama(
model_path=settings.model, model_path=settings.model,
n_gpu_layers=settings.n_gpu_layers, n_gpu_layers=settings.n_gpu_layers,
tensor_split=settings.tensor_split,
seed=settings.seed, seed=settings.seed,
f16_kv=settings.f16_kv, f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock, use_mlock=settings.use_mlock,
@ -178,7 +179,7 @@ def get_settings():
yield settings yield settings
model_field = Field(description="The model to use for generating completions.") model_field = Field(description="The model to use for generating completions.", default=None)
max_tokens_field = Field( max_tokens_field = Field(
default=16, ge=1, le=2048, description="The maximum number of tokens to generate." default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
@ -242,21 +243,18 @@ mirostat_mode_field = Field(
default=0, default=0,
ge=0, ge=0,
le=2, le=2,
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)" description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
) )
mirostat_tau_field = Field( mirostat_tau_field = Field(
default=5.0, default=5.0,
ge=0.0, ge=0.0,
le=10.0, le=10.0,
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text" description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
) )
mirostat_eta_field = Field( mirostat_eta_field = Field(
default=0.1, default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
ge=0.001,
le=1.0,
description="Mirostat learning rate"
) )
@ -294,22 +292,23 @@ class CreateCompletionRequest(BaseModel):
model: Optional[str] = model_field model: Optional[str] = model_field
n: Optional[int] = 1 n: Optional[int] = 1
best_of: Optional[int] = 1 best_of: Optional[int] = 1
user: Optional[str] = Field(None) user: Optional[str] = Field(default=None)
# llama.cpp specific parameters # llama.cpp specific parameters
top_k: int = top_k_field top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
class Config: model_config = {
schema_extra = { "json_schema_extra": {
"example": { "examples": [
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n", {
"stop": ["\n", "###"], "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
} "stop": ["\n", "###"],
}
]
} }
}
def make_logit_bias_processor( def make_logit_bias_processor(
@ -328,7 +327,7 @@ def make_logit_bias_processor(
elif logit_bias_type == "tokens": elif logit_bias_type == "tokens":
for token, score in logit_bias.items(): for token, score in logit_bias.items():
token = token.encode('utf-8') token = token.encode("utf-8")
for input_id in llama.tokenize(token, add_bos=False): for input_id in llama.tokenize(token, add_bos=False):
to_bias[input_id] = score to_bias[input_id] = score
@ -352,7 +351,7 @@ async def create_completion(
request: Request, request: Request,
body: CreateCompletionRequest, body: CreateCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama), llama: llama_cpp.Llama = Depends(get_llama),
): ) -> llama_cpp.Completion:
if isinstance(body.prompt, list): if isinstance(body.prompt, list):
assert len(body.prompt) <= 1 assert len(body.prompt) <= 1
body.prompt = body.prompt[0] if len(body.prompt) > 0 else "" body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
@ -364,7 +363,7 @@ async def create_completion(
"logit_bias_type", "logit_bias_type",
"user", "user",
} }
kwargs = body.dict(exclude=exclude) kwargs = body.model_dump(exclude=exclude)
if body.logit_bias is not None: if body.logit_bias is not None:
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([ kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
@ -396,7 +395,7 @@ async def create_completion(
return EventSourceResponse( return EventSourceResponse(
recv_chan, data_sender_callable=partial(event_publisher, send_chan) recv_chan, data_sender_callable=partial(event_publisher, send_chan)
) ) # type: ignore
else: else:
completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore
return completion return completion
@ -405,16 +404,17 @@ async def create_completion(
class CreateEmbeddingRequest(BaseModel): class CreateEmbeddingRequest(BaseModel):
model: Optional[str] = model_field model: Optional[str] = model_field
input: Union[str, List[str]] = Field(description="The input to embed.") input: Union[str, List[str]] = Field(description="The input to embed.")
user: Optional[str] user: Optional[str] = Field(default=None)
class Config: model_config = {
schema_extra = { "json_schema_extra": {
"example": { "examples": [
"input": "The food was delicious and the waiter...", {
} "input": "The food was delicious and the waiter...",
}
]
} }
}
@router.post( @router.post(
@ -424,7 +424,7 @@ async def create_embedding(
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama) request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
): ):
return await run_in_threadpool( return await run_in_threadpool(
llama.create_embedding, **request.dict(exclude={"user"}) llama.create_embedding, **request.model_dump(exclude={"user"})
) )
@ -461,21 +461,22 @@ class CreateChatCompletionRequest(BaseModel):
repeat_penalty: float = repeat_penalty_field repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
class Config: model_config = {
schema_extra = { "json_schema_extra": {
"example": { "examples": [
"messages": [ {
ChatCompletionRequestMessage( "messages": [
role="system", content="You are a helpful assistant." ChatCompletionRequestMessage(
), role="system", content="You are a helpful assistant."
ChatCompletionRequestMessage( ).model_dump(),
role="user", content="What is the capital of France?" ChatCompletionRequestMessage(
), role="user", content="What is the capital of France?"
] ).model_dump(),
} ]
}
]
} }
}
@router.post( @router.post(
@ -486,14 +487,14 @@ async def create_chat_completion(
body: CreateChatCompletionRequest, body: CreateChatCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama), llama: llama_cpp.Llama = Depends(get_llama),
settings: Settings = Depends(get_settings), settings: Settings = Depends(get_settings),
) -> Union[llama_cpp.ChatCompletion]: # type: ignore ) -> llama_cpp.ChatCompletion:
exclude = { exclude = {
"n", "n",
"logit_bias", "logit_bias",
"logit_bias_type", "logit_bias_type",
"user", "user",
} }
kwargs = body.dict(exclude=exclude) kwargs = body.model_dump(exclude=exclude)
if body.logit_bias is not None: if body.logit_bias is not None:
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([ kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
@ -526,7 +527,7 @@ async def create_chat_completion(
return EventSourceResponse( return EventSourceResponse(
recv_chan, recv_chan,
data_sender_callable=partial(event_publisher, send_chan), data_sender_callable=partial(event_publisher, send_chan),
) ) # type: ignore
else: else:
completion: llama_cpp.ChatCompletion = await run_in_threadpool( completion: llama_cpp.ChatCompletion = await run_in_threadpool(
llama.create_chat_completion, **kwargs # type: ignore llama.create_chat_completion, **kwargs # type: ignore
@ -546,8 +547,6 @@ class ModelList(TypedDict):
data: List[ModelData] data: List[ModelData]
@router.get("/v1/models") @router.get("/v1/models")
async def get_models( async def get_models(
settings: Settings = Depends(get_settings), settings: Settings = Depends(get_settings),

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "llama_cpp_python" name = "llama_cpp_python"
version = "0.1.70" version = "0.1.71"
description = "Python bindings for the llama.cpp library" description = "Python bindings for the llama.cpp library"
authors = ["Andrei Betlen <abetlen@gmail.com>"] authors = ["Andrei Betlen <abetlen@gmail.com>"]
license = "MIT" license = "MIT"

View file

@ -10,7 +10,7 @@ setup(
description="A Python wrapper for llama.cpp", description="A Python wrapper for llama.cpp",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
version="0.1.70", version="0.1.71",
author="Andrei Betlen", author="Andrei Betlen",
author_email="abetlen@gmail.com", author_email="abetlen@gmail.com",
license="MIT", license="MIT",
@ -18,7 +18,7 @@ setup(
packages=["llama_cpp", "llama_cpp.server"], packages=["llama_cpp", "llama_cpp.server"],
install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0", "diskcache>=5.6.1"], install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0", "diskcache>=5.6.1"],
extras_require={ extras_require={
"server": ["uvicorn>=0.22.1", "fastapi>=0.100.0", "pydantic-settings>=2.0.1", "sse-starlette>=1.6.1"], "server": ["uvicorn>=0.22.0", "fastapi>=0.100.0", "pydantic-settings>=2.0.1", "sse-starlette>=1.6.1"],
}, },
python_requires=">=3.7", python_requires=">=3.7",
classifiers=[ classifiers=[

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit a3b4d932859f4e51ed716bfa1f07e2d2eede2c23 Subproject commit 6e7cca404748dd4b1a3affd0d1296e37f4ac0a6f