diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index f57d68c..4fbee37 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs. import os import uvicorn -from llama_cpp.server.app import app, init_llama +from llama_cpp.server.app import create_app if __name__ == "__main__": - init_llama() + app = create_app() uvicorn.run( app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 640dd3f..8e86088 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -2,18 +2,18 @@ import os import json from threading import Lock from typing import List, Optional, Union, Iterator, Dict -from typing_extensions import TypedDict, Literal +from typing_extensions import TypedDict, Literal, Annotated import llama_cpp -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI, APIRouter from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict from sse_starlette.sse import EventSourceResponse class Settings(BaseSettings): - model: str = os.environ.get("MODEL", "null") + model: str n_ctx: int = 2048 n_batch: int = 512 n_threads: int = max((os.cpu_count() or 2) // 2, 1) @@ -27,25 +27,29 @@ class Settings(BaseSettings): vocab_only: bool = False -app = FastAPI( - title="🦙 llama.cpp Python API", - version="0.0.1", -) -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +router = APIRouter() -llama: llama_cpp.Llama = None -def init_llama(settings: Settings = None): +llama: Optional[llama_cpp.Llama] = None + + +def create_app(settings: Optional[Settings] = None): if settings is None: settings = Settings() + app = FastAPI( + title="🦙 llama.cpp Python API", + version="0.0.1", + ) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + app.include_router(router) global llama llama = llama_cpp.Llama( - settings.model, + model_path=settings.model, f16_kv=settings.f16_kv, use_mlock=settings.use_mlock, use_mmap=settings.use_mmap, @@ -60,12 +64,17 @@ def init_llama(settings: Settings = None): if settings.cache: cache = llama_cpp.LlamaCache() llama.set_cache(cache) + return app + llama_lock = Lock() + + def get_llama(): with llama_lock: yield llama + class CreateCompletionRequest(BaseModel): prompt: Union[str, List[str]] suffix: Optional[str] = Field(None) @@ -102,7 +111,7 @@ class CreateCompletionRequest(BaseModel): CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion) -@app.post( +@router.post( "/v1/completions", response_model=CreateCompletionResponse, ) @@ -148,7 +157,7 @@ class CreateEmbeddingRequest(BaseModel): CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding) -@app.post( +@router.post( "/v1/embeddings", response_model=CreateEmbeddingResponse, ) @@ -202,7 +211,7 @@ class CreateChatCompletionRequest(BaseModel): CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion) -@app.post( +@router.post( "/v1/chat/completions", response_model=CreateChatCompletionResponse, ) @@ -256,7 +265,7 @@ class ModelList(TypedDict): GetModelResponse = create_model_from_typeddict(ModelList) -@app.get("/v1/models", response_model=GetModelResponse) +@router.get("/v1/models", response_model=GetModelResponse) def get_models() -> ModelList: return { "object": "list", diff --git a/tests/test_llama.py b/tests/test_llama.py index 2bf38b3..3ea19e0 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch): ## Set up mock function def mock_eval(*args, **kwargs): return 0 - + def mock_get_logits(*args, **kwargs): - return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)]) + return (llama_cpp.c_float * n_vocab)( + *[llama_cpp.c_float(0) for _ in range(n_vocab)] + ) monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) @@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch): def test_llama_pickle(): import pickle import tempfile + fp = tempfile.TemporaryFile() llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) pickle.dump(llama, fp) @@ -101,6 +104,7 @@ def test_llama_pickle(): assert llama.detokenize(llama.tokenize(text)) == text + def test_utf8(monkeypatch): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx)) @@ -110,7 +114,9 @@ def test_utf8(monkeypatch): return 0 def mock_get_logits(*args, **kwargs): - return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)]) + return (llama_cpp.c_float * n_vocab)( + *[llama_cpp.c_float(0) for _ in range(n_vocab)] + ) monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) @@ -143,11 +149,12 @@ def test_utf8(monkeypatch): def test_llama_server(): from fastapi.testclient import TestClient - from llama_cpp.server.app import app, init_llama, Settings - s = Settings() - s.model = MODEL - s.vocab_only = True - init_llama(s) + from llama_cpp.server.app import create_app, Settings + + settings = Settings() + settings.model = MODEL + settings.vocab_only = True + app = create_app(settings) client = TestClient(app) response = client.get("/v1/models") assert response.json() == {