This commit is contained in:
Mug 2023-04-05 14:18:27 +02:00
commit e4c6f34d95
19 changed files with 6212 additions and 123 deletions

30
.github/workflows/test.yaml vendored Normal file
View file

@ -0,0 +1,30 @@
name: Tests
on:
push:
branches:
- main
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
with:
submodules: "true"
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip pytest cmake scikit-build
python3 setup.py develop
- name: Test with pytest
run: |
pytest

View file

@ -1,6 +1,7 @@
# 🦙 Python Bindings for `llama.cpp`
[![Documentation](https://img.shields.io/badge/docs-passing-green.svg)](https://abetlen.github.io/llama-cpp-python)
[![Tests](https://github.com/abetlen/llama-cpp-python/actions/workflows/test.yaml/badge.svg?branch=main)](https://github.com/abetlen/llama-cpp-python/actions/workflows/test.yaml)
[![PyPI](https://img.shields.io/pypi/v/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
[![PyPI - License](https://img.shields.io/pypi/l/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
@ -70,7 +71,7 @@ python3 setup.py develop
# How does this compare to other Python bindings of `llama.cpp`?
I wrote this package for my own use, I had two goals in mind:
I originally wrote this package for my own use with two goals in mind:
- Provide a simple process to install `llama.cpp` and access the full C API in `llama.h` from Python
- Provide a high-level Python API that can be used as a drop-in replacement for the OpenAI API so existing apps can be easily ported to use `llama.cpp`

View file

@ -71,8 +71,10 @@ python3 setup.py develop
- sample
- generate
- create_embedding
- embed
- create_completion
- __call__
- create_chat_completion
- token_bos
- token_eos
show_root_heading: true

View file

@ -1,97 +0,0 @@
"""Example FastAPI server for llama.cpp.
"""
import json
from typing import List, Optional, Iterator
import llama_cpp
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
settings = Settings()
llama = llama_cpp.Llama(
settings.model,
f16_kv=True,
use_mlock=True,
embedding=True,
n_threads=6,
n_batch=2048,
)
class CreateCompletionRequest(BaseModel):
prompt: str
suffix: Optional[str] = Field(None)
max_tokens: int = 16
temperature: float = 0.8
top_p: float = 0.95
logprobs: Optional[int] = Field(None)
echo: bool = False
stop: List[str] = []
repeat_penalty: float = 1.1
top_k: int = 40
stream: bool = False
class Config:
schema_extra = {
"example": {
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
}
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
@app.post(
"/v1/completions",
response_model=CreateCompletionResponse,
)
def create_completion(request: CreateCompletionRequest):
if request.stream:
chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict()) # type: ignore
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
return llama(**request.dict())
class CreateEmbeddingRequest(BaseModel):
model: Optional[str]
input: str
user: Optional[str]
class Config:
schema_extra = {
"example": {
"input": "The food was delicious and the waiter...",
}
}
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
@app.post(
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
def create_embedding(request: CreateEmbeddingRequest):
return llama.create_embedding(request.input)

View file

@ -0,0 +1,181 @@
"""Example FastAPI server for llama.cpp.
To run this example:
```bash
pip install fastapi uvicorn sse-starlette
export MODEL=../models/7B/...
uvicorn fastapi_server_chat:app --reload
```
Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import os
import json
from typing import List, Optional, Literal, Union, Iterator
import llama_cpp
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str
n_ctx: int = 2048
n_batch: int = 2048
n_threads: int = os.cpu_count() or 1
f16_kv: bool = True
use_mlock: bool = True
embedding: bool = True
last_n_tokens_size: int = 64
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
settings = Settings()
llama = llama_cpp.Llama(
settings.model,
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
embedding=settings.embedding,
n_threads=settings.n_threads,
n_batch=settings.n_batch,
n_ctx=settings.n_ctx,
last_n_tokens_size=settings.last_n_tokens_size,
)
class CreateCompletionRequest(BaseModel):
prompt: str
suffix: Optional[str] = Field(None)
max_tokens: int = 16
temperature: float = 0.8
top_p: float = 0.95
logprobs: Optional[int] = Field(None)
echo: bool = False
stop: List[str] = []
repeat_penalty: float = 1.1
top_k: int = 40
stream: bool = False
class Config:
schema_extra = {
"example": {
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
}
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
@app.post(
"/v1/completions",
response_model=CreateCompletionResponse,
)
def create_completion(request: CreateCompletionRequest):
if request.stream:
chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict()) # type: ignore
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
return llama(**request.dict())
class CreateEmbeddingRequest(BaseModel):
model: Optional[str]
input: str
user: Optional[str]
class Config:
schema_extra = {
"example": {
"input": "The food was delicious and the waiter...",
}
}
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
@app.post(
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
def create_embedding(request: CreateEmbeddingRequest):
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
class ChatCompletionRequestMessage(BaseModel):
role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
content: str
user: Optional[str] = None
class CreateChatCompletionRequest(BaseModel):
model: Optional[str]
messages: List[ChatCompletionRequestMessage]
temperature: float = 0.8
top_p: float = 0.95
stream: bool = False
stop: List[str] = []
max_tokens: int = 128
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
),
]
}
}
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
@app.post(
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
async def create_chat_completion(
request: CreateChatCompletionRequest,
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
completion_or_chunks = llama.create_chat_completion(
**request.dict(exclude={"model"}),
)
if request.stream:
async def server_sent_events(
chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
):
for chat_chunk in chat_chunks:
yield dict(data=json.dumps(chat_chunk))
yield dict(data="[DONE]")
chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(
server_sent_events(chunks),
)
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
return completion

View file

@ -11,7 +11,7 @@ llm = Llama(model_path=args.model)
output = llm(
"Question: What are the names of the planets in the solar system? Answer: ",
max_tokens=1,
max_tokens=48,
stop=["Q:", "\n"],
echo=True,
)

View file

@ -4,7 +4,7 @@ import argparse
from llama_cpp import Llama
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default=".//models/...")
parser.add_argument("-m", "--model", type=str, default="./models/...")
args = parser.parse_args()
llm = Llama(model_path=args.model)

View file

@ -0,0 +1,25 @@
import os
import argparse
import llama_cpp
def main(args):
if not os.path.exists(fname_inp):
raise RuntimeError(f"Input file does not exist ({fname_inp})")
if os.path.exists(fname_out):
raise RuntimeError(f"Output file already exists ({fname_out})")
fname_inp = args.fname_inp.encode("utf-8")
fname_out = args.fname_out.encode("utf-8")
itype = args.itype
return_code = llama_cpp.llama_model_quantize(fname_inp, fname_out, itype)
if return_code != 0:
raise RuntimeError("Failed to quantize model")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("fname_inp", type=str, help="Path to input model")
parser.add_argument("fname_out", type=str, help="Path to output model")
parser.add_argument("type", type=int, help="Type of quantization (2: q4_0, 3: q4_1)")
args = parser.parse_args()
main(args)

File diff suppressed because one or more lines are too long

View file

@ -1,8 +1,9 @@
import os
import sys
import uuid
import time
import multiprocessing
from typing import List, Optional, Union, Generator, Sequence
from typing import List, Optional, Union, Generator, Sequence, Iterator
from collections import deque
from . import llama_cpp
@ -27,6 +28,7 @@ class Llama:
n_threads: Optional[int] = None,
n_batch: int = 8,
last_n_tokens_size: int = 64,
verbose: bool = True,
):
"""Load a llama.cpp model from `model_path`.
@ -43,6 +45,7 @@ class Llama:
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
verbose: Print verbose output to stderr.
Raises:
ValueError: If the model path does not exist.
@ -50,6 +53,7 @@ class Llama:
Returns:
A Llama instance.
"""
self.verbose = verbose
self.model_path = model_path
self.params = llama_cpp.llama_context_default_params()
@ -68,7 +72,7 @@ class Llama:
maxlen=self.last_n_tokens_size,
)
self.tokens_consumed = 0
self.n_batch = n_batch
self.n_batch = min(n_ctx, n_batch)
self.n_threads = n_threads or multiprocessing.cpu_count()
@ -79,6 +83,9 @@ class Llama:
self.model_path.encode("utf-8"), self.params
)
if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
"""Tokenize a string.
@ -169,11 +176,6 @@ class Llama:
The sampled token.
"""
assert self.ctx is not None
# Temporary workaround for https://github.com/ggerganov/llama.cpp/issues/684
if temp == 0.0:
temp = 1.0
top_p = 0.0
top_k = 1
return llama_cpp.llama_sample_top_p_top_k(
ctx=self.ctx,
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
@ -239,6 +241,15 @@ class Llama:
An embedding object.
"""
assert self.ctx is not None
if self.params.embedding == False:
raise RuntimeError(
"Llama model must be created with embedding=True to call this method"
)
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
tokens = self.tokenize(input.encode("utf-8"))
self.reset()
self.eval(tokens)
@ -246,6 +257,10 @@ class Llama:
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
: llama_cpp.llama_n_embd(self.ctx)
]
if self.verbose:
llama_cpp.llama_print_timings(self.ctx)
return {
"object": "list",
"data": [
@ -262,6 +277,17 @@ class Llama:
},
}
def embed(self, input: str) -> List[float]:
"""Embed a string.
Args:
input: The utf-8 encoded string to embed.
Returns:
A list of embeddings
"""
return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
def _create_completion(
self,
prompt: str,
@ -275,10 +301,7 @@ class Llama:
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
) -> Union[
Generator[Completion, None, None],
Generator[CompletionChunk, None, None],
]:
) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
assert self.ctx is not None
completion_id = f"cmpl-{str(uuid.uuid4())}"
created = int(time.time())
@ -288,6 +311,9 @@ class Llama:
text = b""
returned_characters = 0
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
raise ValueError(
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
@ -341,7 +367,7 @@ class Llama:
"model": self.model_path,
"choices": [
{
"text": text[start :].decode("utf-8"),
"text": text[start:].decode("utf-8"),
"index": 0,
"logprobs": None,
"finish_reason": None,
@ -384,6 +410,9 @@ class Llama:
if logprobs is not None:
raise NotImplementedError("logprobs not implemented")
if self.verbose:
llama_cpp.llama_print_timings(self.ctx)
yield {
"id": completion_id,
"object": "text_completion",
@ -417,7 +446,7 @@ class Llama:
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
) -> Union[Completion, Generator[CompletionChunk, None, None]]:
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.
Args:
@ -454,7 +483,7 @@ class Llama:
stream=stream,
)
if stream:
chunks: Generator[CompletionChunk, None, None] = completion_or_chunks
chunks: Iterator[CompletionChunk] = completion_or_chunks
return chunks
completion: Completion = next(completion_or_chunks) # type: ignore
return completion
@ -472,7 +501,7 @@ class Llama:
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
):
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.
Args:
@ -509,11 +538,158 @@ class Llama:
stream=stream,
)
def _convert_text_completion_to_chat(
self, completion: Completion
) -> ChatCompletion:
return {
"id": "chat" + completion["id"],
"object": "chat.completion",
"created": completion["created"],
"model": completion["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": completion["choices"][0]["text"],
},
"finish_reason": completion["choices"][0]["finish_reason"],
}
],
"usage": completion["usage"],
}
def _convert_text_completion_chunks_to_chat(
self,
chunks: Iterator[CompletionChunk],
) -> Iterator[ChatCompletionChunk]:
for i, chunk in enumerate(chunks):
if i == 0:
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
},
"finish_reason": None,
}
],
}
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"content": chunk["choices"][0]["text"],
},
"finish_reason": chunk["choices"][0]["finish_reason"],
}
],
}
def create_chat_completion(
self,
messages: List[ChatCompletionMessage],
temperature: float = 0.8,
top_p: float = 0.95,
top_k: int = 40,
stream: bool = False,
stop: List[str] = [],
max_tokens: int = 128,
repeat_penalty: float = 1.1,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
"""Generate a chat completion from a list of messages.
Args:
messages: A list of messages to generate a response for.
temperature: The temperature to use for sampling.
top_p: The top-p value to use for sampling.
top_k: The top-k value to use for sampling.
stream: Whether to stream the results.
stop: A list of strings to stop generation when encountered.
max_tokens: The maximum number of tokens to generate.
repeat_penalty: The penalty to apply to repeated tokens.
Returns:
Generated chat completion or a stream of chat completion chunks.
"""
instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
chat_history = "\n".join(
f'{message["role"]} {message.get("user", "")}: {message["content"]}'
for message in messages
)
PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
completion_or_chunks = self(
prompt=PROMPT,
stop=PROMPT_STOP + stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream=stream,
max_tokens=max_tokens,
repeat_penalty=repeat_penalty,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
return self._convert_text_completion_chunks_to_chat(chunks)
else:
completion: Completion = completion_or_chunks # type: ignore
return self._convert_text_completion_to_chat(completion)
def __del__(self):
if self.ctx is not None:
llama_cpp.llama_free(self.ctx)
self.ctx = None
def __getstate__(self):
return dict(
verbose=self.verbose,
model_path=self.model_path,
n_ctx=self.params.n_ctx,
n_parts=self.params.n_parts,
seed=self.params.seed,
f16_kv=self.params.f16_kv,
logits_all=self.params.logits_all,
vocab_only=self.params.vocab_only,
use_mlock=self.params.use_mlock,
embedding=self.params.embedding,
last_n_tokens_size=self.last_n_tokens_size,
last_n_tokens_data=self.last_n_tokens_data,
tokens_consumed=self.tokens_consumed,
n_batch=self.n_batch,
n_threads=self.n_threads,
)
def __setstate__(self, state):
self.__init__(
model_path=state["model_path"],
n_ctx=state["n_ctx"],
n_parts=state["n_parts"],
seed=state["seed"],
f16_kv=state["f16_kv"],
logits_all=state["logits_all"],
vocab_only=state["vocab_only"],
use_mlock=state["use_mlock"],
embedding=state["embedding"],
n_threads=state["n_threads"],
n_batch=state["n_batch"],
last_n_tokens_size=state["last_n_tokens_size"],
verbose=state["verbose"],
)
self.last_n_tokens_data=state["last_n_tokens_data"]
self.tokens_consumed=state["tokens_consumed"]
@staticmethod
def token_eos() -> llama_cpp.llama_token:
"""Return the end-of-sequence token."""

View file

@ -125,12 +125,12 @@ _lib.llama_free.restype = None
# TODO: not great API - very likely to change
# Returns 0 on success
def llama_model_quantize(
fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int
fname_inp: bytes, fname_out: bytes, itype: c_int
) -> c_int:
return _lib.llama_model_quantize(fname_inp, fname_out, itype, qk)
return _lib.llama_model_quantize(fname_inp, fname_out, itype)
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
_lib.llama_model_quantize.restype = c_int
# Returns the KV cache that will contain the context for the

View file

@ -1,5 +1,5 @@
from typing import List, Optional, Dict, Literal
from typing_extensions import TypedDict
from typing import List, Optional, Dict, Union
from typing_extensions import TypedDict, NotRequired, Literal
class EmbeddingUsage(TypedDict):
@ -55,3 +55,43 @@ class Completion(TypedDict):
model: str
choices: List[CompletionChoice]
usage: CompletionUsage
class ChatCompletionMessage(TypedDict):
role: Union[Literal["assistant"], Literal["user"], Literal["system"]]
content: str
user: NotRequired[str]
class ChatCompletionChoice(TypedDict):
index: int
message: ChatCompletionMessage
finish_reason: Optional[str]
class ChatCompletion(TypedDict):
id: str
object: Literal["chat.completion"]
created: int
model: str
choices: List[ChatCompletionChoice]
usage: CompletionUsage
class ChatCompletionChunkDelta(TypedDict):
role: NotRequired[Literal["assistant"]]
content: NotRequired[str]
class ChatCompletionChunkChoice(TypedDict):
index: int
delta: ChatCompletionChunkDelta
finish_reason: Optional[str]
class ChatCompletionChunk(TypedDict):
id: str
model: str
object: Literal["chat.completion.chunk"]
created: int
choices: List[ChatCompletionChunkChoice]

88
poetry.lock generated
View file

@ -1,5 +1,24 @@
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
[[package]]
name = "attrs"
version = "22.2.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=3.6"
files = [
{file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"},
{file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"},
]
[package.extras]
cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"]
dev = ["attrs[docs,tests]"]
docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"]
tests = ["attrs[tests-no-zope]", "zope.interface"]
tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"]
[[package]]
name = "black"
version = "23.1.0"
@ -328,6 +347,21 @@ files = [
{file = "docutils-0.19.tar.gz", hash = "sha256:33995a6753c30b7f577febfc2c50411fec6aac7f7ffeb7c4cfe5991072dcf9e6"},
]
[[package]]
name = "exceptiongroup"
version = "1.1.1"
description = "Backport of PEP 654 (exception groups)"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"},
{file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"},
]
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "ghp-import"
version = "2.1.0"
@ -415,6 +449,18 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "jaraco-classes"
version = "3.2.3"
@ -821,6 +867,22 @@ files = [
docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
[[package]]
name = "pluggy"
version = "1.0.0"
description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false
python-versions = ">=3.6"
files = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pycparser"
version = "2.21"
@ -864,6 +926,30 @@ files = [
markdown = ">=3.2"
pyyaml = "*"
[[package]]
name = "pytest"
version = "7.2.2"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"},
{file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"},
]
[package.dependencies]
attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
[[package]]
name = "python-dateutil"
version = "2.8.2"
@ -1281,4 +1367,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
content-hash = "cffaf5e2e66ade4f429d0e938277d4fa2c4878ca7338c3c4f91721a7d3aff91b"
content-hash = "cc9babcdfdc3679a4d84f68912408a005619a576947b059146ed1b428850ece9"

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "llama_cpp"
version = "0.1.17"
version = "0.1.22"
description = "Python bindings for the llama.cpp library"
authors = ["Andrei Betlen <abetlen@gmail.com>"]
license = "MIT"
@ -23,6 +23,7 @@ twine = "^4.0.2"
mkdocs = "^1.4.2"
mkdocstrings = {extras = ["python"], version = "^0.20.0"}
mkdocs-material = "^9.1.4"
pytest = "^7.2.2"
[build-system]
requires = [

View file

@ -10,7 +10,7 @@ setup(
description="A Python wrapper for llama.cpp",
long_description=long_description,
long_description_content_type="text/markdown",
version="0.1.17",
version="0.1.22",
author="Andrei Betlen",
author_email="abetlen@gmail.com",
license="MIT",
@ -19,4 +19,12 @@ setup(
"typing-extensions>=4.5.0",
],
python_requires=">=3.7",
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
],
)

96
tests/test_llama.py Normal file
View file

@ -0,0 +1,96 @@
import llama_cpp
MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
def test_llama():
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
assert llama
assert llama.ctx is not None
text = b"Hello World"
assert llama.detokenize(llama.tokenize(text)) == text
def test_llama_patch(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
## Set up mock function
def mock_eval(*args, **kwargs):
return 0
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
output_text = " jumps over the lazy dog."
output_tokens = llama.tokenize(output_text.encode("utf-8"))
token_eos = llama.token_eos()
n = 0
def mock_sample(*args, **kwargs):
nonlocal n
if n < len(output_tokens):
n += 1
return output_tokens[n - 1]
else:
return token_eos
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
text = "The quick brown fox"
## Test basic completion until eos
n = 0 # reset
completion = llama.create_completion(text, max_tokens=20)
assert completion["choices"][0]["text"] == output_text
assert completion["choices"][0]["finish_reason"] == "stop"
## Test streaming completion until eos
n = 0 # reset
chunks = llama.create_completion(text, max_tokens=20, stream=True)
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
assert completion["choices"][0]["finish_reason"] == "stop"
## Test basic completion until stop sequence
n = 0 # reset
completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
assert completion["choices"][0]["text"] == " jumps over the "
assert completion["choices"][0]["finish_reason"] == "stop"
## Test streaming completion until stop sequence
n = 0 # reset
chunks = llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
assert (
"".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
)
assert completion["choices"][0]["finish_reason"] == "stop"
## Test basic completion until length
n = 0 # reset
completion = llama.create_completion(text, max_tokens=2)
assert completion["choices"][0]["text"] == " j"
assert completion["choices"][0]["finish_reason"] == "length"
## Test streaming completion until length
n = 0 # reset
chunks = llama.create_completion(text, max_tokens=2, stream=True)
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
assert completion["choices"][0]["finish_reason"] == "length"
def test_llama_pickle():
import pickle
import tempfile
fp = tempfile.TemporaryFile()
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
pickle.dump(llama, fp)
fp.seek(0)
llama = pickle.load(fp)
assert llama
assert llama.ctx is not None
text = b"Hello World"
assert llama.detokenize(llama.tokenize(text)) == text