Add basic tests. Closes #24

This commit is contained in:
Andrei Betlen 2023-04-05 03:23:15 -04:00
parent 51dbcf2693
commit c3972b61ae
3 changed files with 167 additions and 1 deletions

88
poetry.lock generated
View file

@ -1,5 +1,24 @@
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
[[package]]
name = "attrs"
version = "22.2.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=3.6"
files = [
{file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"},
{file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"},
]
[package.extras]
cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"]
dev = ["attrs[docs,tests]"]
docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"]
tests = ["attrs[tests-no-zope]", "zope.interface"]
tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"]
[[package]]
name = "black"
version = "23.1.0"
@ -328,6 +347,21 @@ files = [
{file = "docutils-0.19.tar.gz", hash = "sha256:33995a6753c30b7f577febfc2c50411fec6aac7f7ffeb7c4cfe5991072dcf9e6"},
]
[[package]]
name = "exceptiongroup"
version = "1.1.1"
description = "Backport of PEP 654 (exception groups)"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"},
{file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"},
]
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "ghp-import"
version = "2.1.0"
@ -415,6 +449,18 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "jaraco-classes"
version = "3.2.3"
@ -821,6 +867,22 @@ files = [
docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
[[package]]
name = "pluggy"
version = "1.0.0"
description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false
python-versions = ">=3.6"
files = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pycparser"
version = "2.21"
@ -864,6 +926,30 @@ files = [
markdown = ">=3.2"
pyyaml = "*"
[[package]]
name = "pytest"
version = "7.2.2"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"},
{file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"},
]
[package.dependencies]
attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
[[package]]
name = "python-dateutil"
version = "2.8.2"
@ -1281,4 +1367,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
content-hash = "cffaf5e2e66ade4f429d0e938277d4fa2c4878ca7338c3c4f91721a7d3aff91b"
content-hash = "cc9babcdfdc3679a4d84f68912408a005619a576947b059146ed1b428850ece9"

View file

@ -23,6 +23,7 @@ twine = "^4.0.2"
mkdocs = "^1.4.2"
mkdocstrings = {extras = ["python"], version = "^0.20.0"}
mkdocs-material = "^9.1.4"
pytest = "^7.2.2"
[build-system]
requires = [

79
tests/test_llama.py Normal file
View file

@ -0,0 +1,79 @@
import llama_cpp
MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
def test_llama():
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
assert llama
assert llama.ctx is not None
text = b"Hello World"
assert llama.detokenize(llama.tokenize(text)) == text
def test_llama_patch(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
## Set up mock function
def mock_eval(*args, **kwargs):
return 0
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
output_text = " jumps over the lazy dog."
output_tokens = llama.tokenize(output_text.encode("utf-8"))
token_eos = llama.token_eos()
n = 0
def mock_sample(*args, **kwargs):
nonlocal n
if n < len(output_tokens):
n += 1
return output_tokens[n - 1]
else:
return token_eos
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
text = "The quick brown fox"
## Test basic completion until eos
n = 0 # reset
completion = llama.create_completion(text, max_tokens=20)
assert completion["choices"][0]["text"] == output_text
assert completion["choices"][0]["finish_reason"] == "stop"
## Test streaming completion until eos
n = 0 # reset
chunks = llama.create_completion(text, max_tokens=20, stream=True)
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
assert completion["choices"][0]["finish_reason"] == "stop"
## Test basic completion until stop sequence
n = 0 # reset
completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
assert completion["choices"][0]["text"] == " jumps over the "
assert completion["choices"][0]["finish_reason"] == "stop"
## Test streaming completion until stop sequence
n = 0 # reset
chunks = llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
assert (
"".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
)
assert completion["choices"][0]["finish_reason"] == "stop"
## Test basic completion until length
n = 0 # reset
completion = llama.create_completion(text, max_tokens=2)
assert completion["choices"][0]["text"] == " j"
assert completion["choices"][0]["finish_reason"] == "length"
## Test streaming completion until length
n = 0 # reset
chunks = llama.create_completion(text, max_tokens=2, stream=True)
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
assert completion["choices"][0]["finish_reason"] == "length"