Merge branch 'main' of github.com:abetlen/llama_cpp_python into main

commit db6ceb67ac
Author: Andrei Betlen
Date:   2023-09-12 16:09:15 -04:00

9 changed files with 91 additions and 60 deletions

README.md

@@ -187,7 +187,8 @@ Below is a short example demonstrating how to use the low-level API to tokenize
 >>> import ctypes
 >>> params = llama_cpp.llama_context_default_params()
 # use bytes for char * params
->>> ctx = llama_cpp.llama_init_from_file(b"./models/7b/ggml-model.bin", params)
+>>> model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
+>>> ctx = llama_cpp.llama_new_context_with_model(model, params)
 >>> max_tokens = params.n_ctx
 # use ctypes arrays for array params
 >>> tokens = (llama_cpp.llama_token * int(max_tokens))()
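The one-step llama_init_from_file is replaced here by a model/context split, so one loaded model can back several contexts. A minimal sketch of the full new flow, assuming a GGML model at the path shown and that the llama_free_model binding is exposed alongside these calls:

    import llama_cpp

    params = llama_cpp.llama_context_default_params()

    # Load the weights once; the model object can back multiple contexts.
    model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
    ctx = llama_cpp.llama_new_context_with_model(model, params)

    # ... tokenize / eval / sample against ctx ...

    # Tear down in reverse order: contexts first, then the model.
    llama_cpp.llama_free(ctx)
    llama_cpp.llama_free_model(model)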

examples/low_level_api/low_level_api_chat_cpp.py

@@ -24,6 +24,10 @@ class LLaMAInteract:
     def __init__(self, params: GptParams) -> None:
         # input args
         self.params = params
+        if self.params.path_session is None:
+            self.params.path_session = ""
+        if self.params.antiprompt is None:
+            self.params.antiprompt = ""
 
         if (self.params.perplexity):
             raise NotImplementedError("""************
@@ -66,7 +70,9 @@ specified) expect poor results""", file=sys.stderr)
         self.lparams.use_mlock = self.params.use_mlock
         self.lparams.use_mmap = self.params.use_mmap
 
-        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
+        self.model = llama_cpp.llama_load_model_from_file(
+            self.params.model.encode("utf8"), self.lparams)
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
 
         if (not self.ctx):
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")
@@ -181,12 +187,12 @@ prompt: '{self.params.prompt}'
 number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)
 
         for i in range(len(self.embd_inp)):
-            print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)
+            print(f"{self.embd_inp[i]} -> '{self.token_to_str(self.embd_inp[i])}'", file=sys.stderr)
 
         if (self.params.n_keep > 0):
             print("static prompt based on n_keep: '")
             for i in range(self.params.n_keep):
-                print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
+                print(self.token_to_str(self.embd_inp[i]), file=sys.stderr)
             print("'", file=sys.stderr)
         print(file=sys.stderr)
@@ -339,7 +345,7 @@ n_keep = {self.params.n_keep}
             candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
 
             # Apply penalties
-            nl_logit = logits[llama_cpp.llama_token_nl()]
+            nl_logit = logits[llama_cpp.llama_token_nl(self.ctx)]
             last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
 
             _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
@@ -380,7 +386,7 @@ n_keep = {self.params.n_keep}
                 self.last_n_tokens.append(id)
 
                 # replace end of text token with newline token when in interactive mode
-                if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
+                if (id == llama_cpp.llama_token_eos(self.ctx) and self.params.interactive and not self.params.instruct):
                     id = self.llama_token_newline[0]
                     self.embd.append(id)
                     if (self.use_antiprompt()):
@@ -437,7 +443,7 @@ n_keep = {self.params.n_keep}
                     break
 
             # end of text token
-            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
+            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(self.ctx):
                 if (not self.params.instruct):
                     for i in self.llama_token_eot:
                         yield i
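As in the hunks above, the special-token helpers now take the context argument, matching the updated llama.cpp API; the new call shape, for reference:

    eos_id = llama_cpp.llama_token_eos(ctx)  # end-of-sequence token id
    nl_id = llama_cpp.llama_token_nl(ctx)    # newline token id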
@@ -464,10 +470,18 @@ n_keep = {self.params.n_keep}
         llama_cpp.llama_free(self.ctx)
         self.set_color(util.CONSOLE_COLOR_DEFAULT)
 
+    def token_to_str(self, token_id: int) -> bytes:
+        size = 32
+        buffer = (ctypes.c_char * size)()
+        n = llama_cpp.llama_token_to_piece_with_model(
+            self.model, llama_cpp.llama_token(token_id), buffer, size)
+        assert n <= size
+        return bytes(buffer[:n])
+
     # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")
+            yield self.token_to_str(id).decode("utf8", errors="ignore")
 
     # write input
     def input(self, prompt: str):
@@ -481,7 +495,7 @@ n_keep = {self.params.n_keep}
     def output(self):
         self.remaining_tokens = self.params.n_predict
         for id in self.generate():
-            cur_char = llama_cpp.llama_token_to_str(self.ctx, id)
+            cur_char = self.token_to_str(id)
 
             # Add remainder of missing bytes
             if None in self.multibyte_fix:
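The fixed 32-byte buffer in token_to_str above is ample for a single token piece, but a size-negotiating variant is possible. A hypothetical sketch, assuming the upstream convention that a negative return from llama_token_to_piece_with_model is the negated required length (an assumption, not something this diff confirms):

    import ctypes
    import llama_cpp

    def token_to_piece(model, token_id: int) -> bytes:
        size = 32
        buffer = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece_with_model(
            model, llama_cpp.llama_token(token_id), buffer, size)
        if n < 0:
            # Assumed convention: -n is the required buffer size; retry once.
            size = -n
            buffer = (ctypes.c_char * size)()
            n = llama_cpp.llama_token_to_piece_with_model(
                model, llama_cpp.llama_token(token_id), buffer, size)
        assert 0 <= n <= size
        return bytes(buffer[:n])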

examples/low_level_api/low_level_api_llama_cpp.py

@@ -1,15 +1,17 @@
-import llama_cpp
+import ctypes
+import os
 import multiprocessing
 
+import llama_cpp
 
 N_THREADS = multiprocessing.cpu_count()
+MODEL_PATH = os.environ.get('MODEL', "../models/7B/ggml-model.bin")
 
 prompt = b"\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\n"
 
 lparams = llama_cpp.llama_context_default_params()
-ctx = llama_cpp.llama_init_from_file(b"../models/7B/ggml-model.bin", lparams)
+model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode('utf-8'), lparams)
+ctx = llama_cpp.llama_new_context_with_model(model, lparams)
 
 # determine the required inference memory per token:
 tmp = [0, 1, 2, 3]
@@ -58,7 +60,8 @@ while remaining_tokens > 0:
         llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
         for token_id in range(n_vocab)
     ])
-    candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
+    candidates_p = llama_cpp.ctypes.pointer(
+        llama_cpp.llama_token_data_array(_arr, len(_arr), False))
 
     _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
     llama_cpp.llama_sample_repetition_penalty(ctx, candidates_p,
@@ -68,9 +71,9 @@ while remaining_tokens > 0:
         _arr,
         last_n_repeat, frequency_penalty, presence_penalty)
 
-    llama_cpp.llama_sample_top_k(ctx, candidates_p, 40)
-    llama_cpp.llama_sample_top_p(ctx, candidates_p, 0.8)
-    llama_cpp.llama_sample_temperature(ctx, candidates_p, 0.2)
+    llama_cpp.llama_sample_top_k(ctx, candidates_p, k=40, min_keep=1)
+    llama_cpp.llama_sample_top_p(ctx, candidates_p, p=0.8, min_keep=1)
+    llama_cpp.llama_sample_temperature(ctx, candidates_p, temp=0.2)
     id = llama_cpp.llama_sample_token(ctx, candidates_p)
 
     last_n_tokens_data = last_n_tokens_data[1:] + [id]
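The keyword form above makes the min_keep floor explicit for each sampler. For deterministic decoding, the whole top-k/top-p/temperature chain can be skipped in favor of arg-max selection, assuming the llama_sample_token_greedy binding is available in this version:

    # Deterministic alternative: take the most probable token directly,
    # skipping the top_k/top_p/temperature stages entirely.
    id = llama_cpp.llama_sample_token_greedy(ctx, candidates_p)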
@@ -86,13 +89,18 @@ while remaining_tokens > 0:
             break
 
     if not input_noecho:
         for id in embd:
+            size = 32
+            buffer = (ctypes.c_char * size)()
+            n = llama_cpp.llama_token_to_piece_with_model(
+                model, llama_cpp.llama_token(id), buffer, size)
+            assert n <= size
             print(
-                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
+                buffer[:n].decode('utf-8'),
                 end="",
                 flush=True,
             )
 
-    if len(embd) > 0 and embd[-1] == llama_cpp.llama_token_eos():
+    if len(embd) > 0 and embd[-1] == llama_cpp.llama_token_eos(ctx):
         break
 
 print()

llama_cpp/__init__.py

@@ -1,2 +1,4 @@
 from .llama_cpp import *
 from .llama import *
+
+from .version import __version__

llama_cpp/version.py (new file)

@@ -0,0 +1 @@
+__version__ = "0.1.84"

poetry.lock (generated)

@@ -51,33 +51,33 @@ pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""}
 
 [[package]]
 name = "black"
-version = "23.7.0"
+version = "23.9.1"
 description = "The uncompromising code formatter."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"},
-    {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"},
-    {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"},
-    {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"},
-    {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"},
-    {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"},
-    {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"},
-    {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"},
-    {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"},
-    {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"},
-    {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"},
-    {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"},
-    {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"},
-    {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"},
-    {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"},
-    {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"},
-    {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"},
-    {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"},
-    {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"},
-    {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"},
-    {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"},
-    {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"},
+    {file = "black-23.9.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:d6bc09188020c9ac2555a498949401ab35bb6bf76d4e0f8ee251694664df6301"},
+    {file = "black-23.9.1-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:13ef033794029b85dfea8032c9d3b92b42b526f1ff4bf13b2182ce4e917f5100"},
+    {file = "black-23.9.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:75a2dc41b183d4872d3a500d2b9c9016e67ed95738a3624f4751a0cb4818fe71"},
+    {file = "black-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13a2e4a93bb8ca74a749b6974925c27219bb3df4d42fc45e948a5d9feb5122b7"},
+    {file = "black-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:adc3e4442eef57f99b5590b245a328aad19c99552e0bdc7f0b04db6656debd80"},
+    {file = "black-23.9.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:8431445bf62d2a914b541da7ab3e2b4f3bc052d2ccbf157ebad18ea126efb91f"},
+    {file = "black-23.9.1-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:8fc1ddcf83f996247505db6b715294eba56ea9372e107fd54963c7553f2b6dfe"},
+    {file = "black-23.9.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:7d30ec46de88091e4316b17ae58bbbfc12b2de05e069030f6b747dfc649ad186"},
+    {file = "black-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:031e8c69f3d3b09e1aa471a926a1eeb0b9071f80b17689a655f7885ac9325a6f"},
+    {file = "black-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:538efb451cd50f43aba394e9ec7ad55a37598faae3348d723b59ea8e91616300"},
+    {file = "black-23.9.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:638619a559280de0c2aa4d76f504891c9860bb8fa214267358f0a20f27c12948"},
+    {file = "black-23.9.1-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:a732b82747235e0542c03bf352c126052c0fbc458d8a239a94701175b17d4855"},
+    {file = "black-23.9.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:cf3a4d00e4cdb6734b64bf23cd4341421e8953615cba6b3670453737a72ec204"},
+    {file = "black-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf99f3de8b3273a8317681d8194ea222f10e0133a24a7548c73ce44ea1679377"},
+    {file = "black-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:14f04c990259576acd093871e7e9b14918eb28f1866f91968ff5524293f9c573"},
+    {file = "black-23.9.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:c619f063c2d68f19b2d7270f4cf3192cb81c9ec5bc5ba02df91471d0b88c4c5c"},
+    {file = "black-23.9.1-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:6a3b50e4b93f43b34a9d3ef00d9b6728b4a722c997c99ab09102fd5efdb88325"},
+    {file = "black-23.9.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c46767e8df1b7beefb0899c4a95fb43058fa8500b6db144f4ff3ca38eb2f6393"},
+    {file = "black-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50254ebfa56aa46a9fdd5d651f9637485068a1adf42270148cd101cdf56e0ad9"},
+    {file = "black-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:403397c033adbc45c2bd41747da1f7fc7eaa44efbee256b53842470d4ac5a70f"},
+    {file = "black-23.9.1-py3-none-any.whl", hash = "sha256:6ccd59584cc834b6d127628713e4b6b968e5f79572da66284532525a042549f9"},
+    {file = "black-23.9.1.tar.gz", hash = "sha256:24b6b3ff5c6d9ea08a8888f6977eae858e1f340d7260cf56d70a49823236b62d"},
 ]
 
 [package.dependencies]
@@ -87,7 +87,7 @@ packaging = ">=22.0"
 pathspec = ">=0.9.0"
 platformdirs = ">=2"
 tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
-typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}
+typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
 
 [package.extras]
 colorama = ["colorama (>=0.4.3)"]
@@ -461,13 +461,13 @@ files = [
 
 [[package]]
 name = "httpcore"
-version = "0.17.0"
+version = "0.18.0"
 description = "A minimal low-level HTTP client."
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "httpcore-0.17.0-py3-none-any.whl", hash = "sha256:0fdfea45e94f0c9fd96eab9286077f9ff788dd186635ae61b312693e4d943599"},
-    {file = "httpcore-0.17.0.tar.gz", hash = "sha256:cc045a3241afbf60ce056202301b4d8b6af08845e3294055eb26b09913ef903c"},
+    {file = "httpcore-0.18.0-py3-none-any.whl", hash = "sha256:adc5398ee0a476567bf87467063ee63584a8bce86078bf748e48754f60202ced"},
+    {file = "httpcore-0.18.0.tar.gz", hash = "sha256:13b5e5cd1dca1a6636a6aaea212b19f4f85cd88c366a2b82304181b769aab3c9"},
 ]
 
 [package.dependencies]
@@ -482,18 +482,18 @@ socks = ["socksio (==1.*)"]
 
 [[package]]
 name = "httpx"
-version = "0.24.1"
+version = "0.25.0"
 description = "The next generation HTTP client."
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "httpx-0.24.1-py3-none-any.whl", hash = "sha256:06781eb9ac53cde990577af654bd990a4949de37a28bdb4a230d434f3a30b9bd"},
-    {file = "httpx-0.24.1.tar.gz", hash = "sha256:5853a43053df830c20f8110c5e69fe44d035d850b2dfe795e196f00fdb774bdd"},
+    {file = "httpx-0.25.0-py3-none-any.whl", hash = "sha256:181ea7f8ba3a82578be86ef4171554dd45fec26a02556a744db029a0a27b7100"},
+    {file = "httpx-0.25.0.tar.gz", hash = "sha256:47ecda285389cb32bb2691cc6e069e3ab0205956f681c5b2ad2325719751d875"},
 ]
 
 [package.dependencies]
 certifi = "*"
-httpcore = ">=0.15.0,<0.18.0"
+httpcore = ">=0.18.0,<0.19.0"
 idna = "*"
 sniffio = "*"
@@ -807,13 +807,13 @@ mkdocs = ">=1.1"
 
 [[package]]
 name = "mkdocs-material"
-version = "9.2.8"
+version = "9.3.1"
 description = "Documentation that simply works"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "mkdocs_material-9.2.8-py3-none-any.whl", hash = "sha256:6bc8524f8047a4f060d6ab0925b9d7cb61b3b5e6d5ca8a8e8085f8bfdeca1b71"},
-    {file = "mkdocs_material-9.2.8.tar.gz", hash = "sha256:ec839dc5eaf42d8525acd1d6420fd0a0583671a4f98a9b3ff7897ae8628dbc2d"},
+    {file = "mkdocs_material-9.3.1-py3-none-any.whl", hash = "sha256:614cdd1d695465375e0f50bfe9881db1eb68d1f17164b8edfedcda1457e61894"},
+    {file = "mkdocs_material-9.3.1.tar.gz", hash = "sha256:793c4ec0978582380491a68db8ac4f7e0d5467a736c9884c05baf95a143f32f7"},
 ]
 
 [package.dependencies]
@@ -1800,4 +1800,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.1"
-content-hash = "46f61f5544850ef5226db1475d3fa8627a14e23209b1465181dda0e6704409e9"
+content-hash = "af67f2afd23599bc5231a3c0ae735bfcae5596bb6d149b379c614e721f6917b9"

pyproject.toml

@@ -23,13 +23,13 @@ sse-starlette = { version = ">=1.6.1", optional = true }
 pydantic-settings = { version = ">=2.0.1", optional = true }
 
 [tool.poetry.group.dev.dependencies]
-black = "^23.7.0"
+black = "^23.9.1"
 twine = "^4.0.2"
 mkdocs = "^1.5.2"
 mkdocstrings = {extras = ["python"], version = "^0.23.0"}
-mkdocs-material = "^9.2.8"
+mkdocs-material = "^9.3.1"
 pytest = "^7.4.2"
-httpx = "^0.24.1"
+httpx = "^0.25.0"
 scikit-build = "0.17.6"
 
 [tool.poetry.extras]

setup.py

@@ -5,12 +5,14 @@ from pathlib import Path
 this_directory = Path(__file__).parent
 long_description = (this_directory / "README.md").read_text(encoding="utf-8")
 
+exec(open('llama_cpp/version.py').read())
+
 setup(
     name="llama_cpp_python",
     description="A Python wrapper for llama.cpp",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    version="0.1.84",
+    version=__version__,
     author="Andrei Betlen",
     author_email="abetlen@gmail.com",
     license="MIT",
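exec-ing version.py keeps setup.py from importing the llama_cpp package itself (which would load the native shared library at build time) while still reading the version from a single place. A more explicit equivalent of what the exec line does:

    # What exec(open('llama_cpp/version.py').read()) achieves, made explicit:
    namespace = {}
    with open("llama_cpp/version.py") as f:
        exec(f.read(), namespace)
    version = namespace["__version__"]  # "0.1.84" at this commit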

tests/test_llama.py

@@ -181,3 +181,6 @@ def test_llama_server():
             }
         ],
     }
+
+def test_llama_cpp_version():
+    assert llama_cpp.__version__
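With version.py re-exported from the package root via __init__.py, the version is introspectable at runtime:

    import llama_cpp

    # Re-exported from llama_cpp/version.py via llama_cpp/__init__.py.
    print(llama_cpp.__version__)  # "0.1.84" at this commit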