Merge branch 'main' into fix-on-m1

This commit is contained in:
Andrei Betlen 2023-08-08 14:38:35 -04:00
commit bf0c603c51
7 changed files with 83 additions and 22 deletions

View file

@ -169,7 +169,7 @@ docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggm
## Low-level API
The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`.
The entire lowe-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
Below is a short example demonstrating how to use the low-level API to tokenize a prompt:

View file

@ -27,6 +27,8 @@ from .llama_types import *
import numpy as np
import numpy.typing as npt
from .utils import suppress_stdout_stderr
class BaseLlamaCache(ABC):
"""Base cache class for a llama.cpp model."""
@ -224,7 +226,8 @@ class Llama:
rope_freq_base: float = 10000.0,
rope_freq_scale: float = 1.0,
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
mul_mat_q: Optional(bool) = None, # (TEMPORARY)
verbose: bool = True,
):
"""Load a llama.cpp model from `model_path`.
@ -277,7 +280,9 @@ class Llama:
if self.tensor_split is not None:
FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(
FloatArray
) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._p_tensor_split
self.params.rope_freq_base = rope_freq_base
@ -289,6 +294,9 @@ class Llama:
if rms_norm_eps is not None:
self.params.rms_norm_eps = rms_norm_eps
if mul_mat_q is not None:
self.params.mul_mat_q = mul_mat_q
self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch)
@ -306,12 +314,25 @@ class Llama:
if not os.path.exists(model_path):
raise ValueError(f"Model path does not exist: {model_path}")
self.model = llama_cpp.llama_load_model_from_file(
self.model_path.encode("utf-8"), self.params
)
if verbose:
self.model = llama_cpp.llama_load_model_from_file(
self.model_path.encode("utf-8"), self.params
)
else:
with suppress_stdout_stderr():
self.model = llama_cpp.llama_load_model_from_file(
self.model_path.encode("utf-8"), self.params
)
assert self.model is not None
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
if verbose:
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
else:
with suppress_stdout_stderr():
print("here")
self.ctx = llama_cpp.llama_new_context_with_model(
self.model, self.params
)
assert self.ctx is not None
@ -959,9 +980,7 @@ class Llama:
for token in remaining_tokens:
token_end_position += len(self.detokenize([token]))
# Check if stop sequence is in the token
if token_end_position >= (
remaining_length - first_stop_position
):
if token_end_position >= (remaining_length - first_stop_position):
break
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
@ -1503,10 +1522,10 @@ class Llama:
return self._convert_text_completion_to_chat(completion)
def __del__(self):
if self.model is not None:
if hasattr(self, "model") and self.model is not None:
llama_cpp.llama_free_model(self.model)
self.model = None
if self.ctx is not None:
if hasattr(self, "ctx") and self.ctx is not None:
llama_cpp.llama_free(self.ctx)
self.ctx = None

View file

@ -103,6 +103,10 @@ class Settings(BaseSettings):
default=None,
description="TEMPORARY",
)
mul_mat_q: Optional[bool] = Field(
default=None,
description="TEMPORARY",
)
class ErrorResponse(TypedDict):

38
llama_cpp/utils.py Normal file
View file

@ -0,0 +1,38 @@
import os
import sys
class suppress_stdout_stderr(object):
# Oddly enough this works better than the contextlib version
def __enter__(self):
self.outnull_file = open(os.devnull, "w")
self.errnull_file = open(os.devnull, "w")
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()

16
poetry.lock generated
View file

@ -384,17 +384,17 @@ test = ["pytest (>=6)"]
[[package]]
name = "fastapi"
version = "0.100.1"
version = "0.101.0"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = true
python-versions = ">=3.7"
files = [
{file = "fastapi-0.100.1-py3-none-any.whl", hash = "sha256:ec6dd52bfc4eff3063cfcd0713b43c87640fefb2687bbbe3d8a08d94049cdf32"},
{file = "fastapi-0.100.1.tar.gz", hash = "sha256:522700d7a469e4a973d92321ab93312448fbe20fca9c8da97effc7e7bc56df23"},
{file = "fastapi-0.101.0-py3-none-any.whl", hash = "sha256:494eb3494d89e8079c20859d7ca695f66eaccc40f46fe8c75ab6186d15f05ffd"},
{file = "fastapi-0.101.0.tar.gz", hash = "sha256:ca2ae65fe42f6a34b5cf6c994337149154b1b400c39809d7b2dccdceb5ae77af"},
]
[package.dependencies]
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<3.0.0"
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
starlette = ">=0.27.0,<0.28.0"
typing-extensions = ">=4.5.0"
@ -744,13 +744,13 @@ files = [
[[package]]
name = "mkdocs"
version = "1.5.1"
version = "1.5.2"
description = "Project documentation with Markdown."
optional = false
python-versions = ">=3.7"
files = [
{file = "mkdocs-1.5.1-py3-none-any.whl", hash = "sha256:67e889f8d8ba1fe5decdfc59f5f8f21d6a8925a129339e93dede303bdea03a98"},
{file = "mkdocs-1.5.1.tar.gz", hash = "sha256:f2f323c62fffdf1b71b84849e39aef56d6852b3f0a5571552bca32cefc650209"},
{file = "mkdocs-1.5.2-py3-none-any.whl", hash = "sha256:60a62538519c2e96fe8426654a67ee177350451616118a41596ae7c876bb7eac"},
{file = "mkdocs-1.5.2.tar.gz", hash = "sha256:70d0da09c26cff288852471be03c23f0f521fc15cf16ac89c7a3bfb9ae8d24f9"},
]
[package.dependencies]
@ -1757,4 +1757,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
content-hash = "6718d680fa89f9518a232c1110ba43958d3e21c54c4dbd9129effa4f40a02b81"
content-hash = "4bfb67dfb72b02c845376211f7f958b2ece8c985944fbd03d246c858e846ddf6"

View file

@ -25,7 +25,7 @@ pydantic-settings = { version = ">=2.0.1", optional = true }
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
twine = "^4.0.2"
mkdocs = "^1.4.3"
mkdocs = "^1.5.2"
mkdocstrings = {extras = ["python"], version = "^0.22.0"}
mkdocs-material = "^9.1.21"
pytest = "^7.4.0"

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 8183159cf3def112f6d1fe94815fce70e1bffa12
Subproject commit f5bfea0580e417f99850d5456ca541d871a3e48c