Refactor server to use factory

Andrei Betlen 2023-05-01 22:38:46 -04:00
parent dd9ad1c759
commit 9eafc4c49a
3 changed files with 47 additions and 31 deletions

llama_cpp/server/__main__.py

@@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 import os
 import uvicorn
 
-from llama_cpp.server.app import app, init_llama
+from llama_cpp.server.app import create_app
 
 if __name__ == "__main__":
-    init_llama()
+    app = create_app()
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
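With an app factory in place, uvicorn can also build the application itself through its factory support (the factory flag has been available since uvicorn 0.12). A minimal sketch, assuming the module layout above:

# Alternative launch: hand uvicorn an import string naming the factory
# instead of importing a module-level app object.
import os
import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "llama_cpp.server.app:create_app",
        factory=True,  # uvicorn calls create_app() to obtain the app
        host=os.getenv("HOST", "localhost"),
        port=int(os.getenv("PORT", 8000)),
    )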

llama_cpp/server/app.py

@@ -2,18 +2,18 @@ import os
 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
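Dropping the os.environ.get default makes model a required setting: pydantic's BaseSettings still resolves it from the environment (field names match environment variables case-insensitively in pydantic v1), or it can be passed explicitly. A minimal sketch with a placeholder model path:

import os

# Option 1: via the environment; BaseSettings matches "MODEL" to the
# lowercase field name "model".
os.environ["MODEL"] = "./models/ggml-model.bin"  # placeholder path
settings = Settings()

# Option 2: explicitly, bypassing the environment.
settings = Settings(model="./models/ggml-model.bin")  # placeholder path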
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False
 
 
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
 
-llama: llama_cpp.Llama = None
+llama: Optional[llama_cpp.Llama] = None
 
 
-def init_llama(settings: Settings = None):
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
@@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
 
 
 llama_lock = Lock()
 
 
 def get_llama():
     with llama_lock:
         yield llama
 
 
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
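The shape of the refactor in miniature: handlers register once on a module-level APIRouter, and each create_app() call builds a fresh FastAPI instance, mounts that router, and (re)initializes the module-level llama global. A stripped-down sketch of the pattern; the /ping endpoint and state variable are illustrative, not part of the real API:

from typing import Optional

from fastapi import APIRouter, FastAPI

router = APIRouter()  # handlers attach here at import time

state: Optional[str] = None  # stand-in for the module-level llama global


@router.get("/ping")  # illustrative endpoint only
def ping():
    return {"state": state}


def create_app() -> FastAPI:
    global state
    state = "loaded"  # per-call setup, like constructing llama_cpp.Llama
    app = FastAPI()
    app.include_router(router)  # mount the pre-registered routes
    return app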
@@ -102,7 +111,7 @@ class CreateCompletionRequest(BaseModel):
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
@@ -148,7 +157,7 @@ class CreateEmbeddingRequest(BaseModel):
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 
 
-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
@@ -202,7 +211,7 @@ class CreateChatCompletionRequest(BaseModel):
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 
 
-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
@@ -256,7 +265,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)
 
 
-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",

tests/test_llama.py

@@ -24,7 +24,9 @@ def test_llama_patch(monkeypatch):
         return 0
 
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
 def test_llama_pickle():
     import pickle
     import tempfile
+
     fp = tempfile.TemporaryFile()
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     pickle.dump(llama, fp)
@@ -101,6 +104,7 @@ def test_llama_pickle():
     assert llama.detokenize(llama.tokenize(text)) == text
 
+
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
         return 0
 
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,12 @@ def test_utf8(monkeypatch):
 def test_llama_server():
     from fastapi.testclient import TestClient
-    from llama_cpp.server.app import app, init_llama, Settings
-    s = Settings()
-    s.model = MODEL
-    s.vocab_only = True
-    init_llama(s)
+    from llama_cpp.server.app import create_app, Settings
+
+    settings = Settings()
+    settings.model = MODEL
+    settings.vocab_only = True
+    app = create_app(settings)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
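Because create_app returns the application, each test builds an isolated instance instead of mutating shared module state, and vocab_only=True keeps setup cheap by loading only the tokenizer vocabulary. The same pattern works for any ad-hoc client; a minimal sketch with a placeholder model path:

from fastapi.testclient import TestClient
from llama_cpp.server.app import create_app, Settings

settings = Settings(model="./models/ggml-vocab.bin")  # placeholder path
settings.vocab_only = True  # vocabulary only, no weights
client = TestClient(create_app(settings))
assert client.get("/v1/models").status_code == 200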