diff --git a/CMakeLists.txt b/CMakeLists.txt index c633c07..8d06370 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,4 +41,23 @@ if (LLAMA_BUILD) FILES $ DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp ) + add_subdirectory(vendor/llama.cpp/examples/llava) + set_target_properties(llava_shared PROPERTIES OUTPUT_NAME "llava") + install( + TARGETS llava_shared + LIBRARY DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp + RUNTIME DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp + ARCHIVE DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp + FRAMEWORK DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp + RESOURCE DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp + ) + # Temporary fix for https://github.com/scikit-build/scikit-build-core/issues/374 + install( + TARGETS llava_shared + LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp + RUNTIME DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp + ARCHIVE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp + FRAMEWORK DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp + RESOURCE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp + ) endif() diff --git a/docs/server.md b/docs/server.md new file mode 100644 index 0000000..e7d4bb6 --- /dev/null +++ b/docs/server.md @@ -0,0 +1,77 @@ +# OpenAI Compatible Server + +`llama-cpp-python` offers an OpenAI API compatible web server. + +This web server can be used to serve local models and easily connect them to existing clients. + +## Setup + +### Installation + +The server can be installed by running the following command: + +```bash +pip install llama-cpp-python[server] +``` + +### Running the server + +The server can then be started by running the following command: + +```bash +python3 -m llama_cpp.server --model <model_path> +``` + +### Server options + +For a full list of options, run: + +```bash +python3 -m llama_cpp.server --help +``` + +NOTE: All server options are also available as environment variables. For example, `--model` can be set by setting the `MODEL` environment variable. + +## Guides + +### Multi-modal Models + +`llama-cpp-python` supports the llava1.5 family of multi-modal models which allow the language model to
read information from both text and images.
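The same capability is also exposed directly from the Python API through the new `chat_handler` parameter on `Llama`, using the `Llava15ChatHandler` added in this change. A minimal sketch follows; the model path, CLIP (mmproj) path, and image URL are placeholders for the files linked below:

```python
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

# Placeholder paths: point these at a llava1.5 GGUF model and its mmproj/CLIP file.
chat_handler = Llava15ChatHandler(clip_model_path="path/to/mmproj-model-f16.gguf")
llm = Llama(
    model_path="path/to/llava-v1.5-7b.Q4_K.gguf",
    chat_handler=chat_handler,
    logits_all=True,  # the llava 1.5 handler requires logits for all tokens
    n_ctx=2048,       # leave room for the image embedding in the context
)
response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
                {"type": "text", "text": "What does the image say"},
            ],
        }
    ],
)
print(response["choices"][0]["message"]["content"])
```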
+ +You'll first need to download one of the available multi-modal models in GGUF format: + +- [llava1.5 7b](https://huggingface.co/mys/ggml_llava-v1.5-7b) +- [llava1.5 13b](https://huggingface.co/mys/ggml_llava-v1.5-13b) + +Then when you run the server you'll need to also specify the path to the clip model used for image embedding + +```bash +python3 -m llama_cpp.server --model --clip-model-path +``` + +Then you can just use the OpenAI API as normal + +```python3 +from openai import OpenAI + +client = OpenAI(base_url="http://:/v1", api_key="sk-xxx") +response = client.chat.completions.create( + model="gpt-4-vision-preview", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + {"type": "text", "text": "What does the image say"}, + ], + } + ], +) +print(response) +``` \ No newline at end of file diff --git a/examples/notebooks/Multimodal.ipynb b/examples/notebooks/Multimodal.ipynb new file mode 100644 index 0000000..11b14df --- /dev/null +++ b/examples/notebooks/Multimodal.ipynb @@ -0,0 +1,84 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ChatCompletion(id='chatcmpl-65a710ba-41d1-4d0a-a124-a44b2b4a0189', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content=' The image reads \"LlamaC++.\"', role='assistant', function_call=None, tool_calls=None))], created=1699413274, model='gpt-4-vision-preview', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=10, prompt_tokens=624, total_tokens=634))\n" + ] + } + ], + "source": [ + "from openai import OpenAI\n", + "\n", + "import urllib.request\n", + "import base64\n", + "\n", + "def get_data_url(url):\n", + " return \"data:image/png;base64,\" + base64.b64encode(urllib.request.urlopen(url).read()).decode(\"utf-8\")\n", + "\n", + "client = OpenAI(base_url=\"http://100.64.159.73:8000/v1\", api_key=\"sk-1234\")\n", + "response = client.chat.completions.create(\n", + " model=\"gpt-4-vision-preview\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\n", + " \"url\": get_data_url(\"https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png\"),\n", + " # \"url\": \"https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png\",\n", + " },\n", + " },\n", + " {\"type\": \"text\", \"text\": \"What does the image say\"},\n", + " ],\n", + " }\n", + " ],\n", + ")\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5+" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 6dc113a..7a2c34f 100644 --- a/llama_cpp/llama.py +++ 
b/llama_cpp/llama.py @@ -21,9 +21,9 @@ from collections import deque, OrderedDict import diskcache import ctypes -from . import llama_cpp from .llama_types import * from .llama_grammar import LlamaGrammar +import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_chat_format as llama_chat_format import numpy as np @@ -752,6 +752,7 @@ class Llama: numa: bool = False, # Chat Format Params chat_format: str = "llama-2", + chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None, # Misc verbose: bool = True, # Extra Params @@ -784,6 +785,7 @@ class Llama: lora_path: Path to a LoRA file to apply to the model. numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init) chat_format: String specifying the chat format to use when calling create_chat_completion. + chat_handler: Optional chat handler to use when calling create_chat_completion. verbose: Print verbose output to stderr. Raises: @@ -910,6 +912,7 @@ class Llama: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) self.chat_format = chat_format + self.chat_handler = chat_handler self._n_vocab = self.n_vocab() self._n_ctx = self.n_ctx() @@ -1231,7 +1234,7 @@ class Llama: else: inputs = input - data: List[EmbeddingData] = [] + data: List[Embedding] = [] total_tokens = 0 for index, input in enumerate(inputs): tokens = self.tokenize(input.encode("utf-8"), special=True) @@ -1276,7 +1279,7 @@ class Llama: def _create_completion( self, - prompt: str, + prompt: Union[str, List[int]], suffix: Optional[str] = None, max_tokens: int = 16, temperature: float = 0.8, @@ -1297,7 +1300,9 @@ class Llama: stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, - ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: + ) -> Union[ + Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse] + ]: assert self._ctx is not None assert suffix is None or suffix.__class__ is str @@ -1309,7 +1314,7 @@ class Llama: self.tokenize(prompt.encode("utf-8"), special=True) if prompt != "" else [self.token_bos()] - ) + ) if isinstance(prompt, str) else prompt text: bytes = b"" returned_tokens: int = 0 stop = ( @@ -1322,7 +1327,7 @@ class Llama: if len(prompt_tokens) >= self._n_ctx: raise ValueError( - f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self._ctx)}" + f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" ) if max_tokens <= 0: @@ -1732,7 +1737,7 @@ class Llama: def create_completion( self, - prompt: str, + prompt: Union[str, List[int]], suffix: Optional[str] = None, max_tokens: int = 128, temperature: float = 0.8, @@ -1753,7 +1758,7 @@ class Llama: stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, - ) -> Union[Completion, Iterator[CompletionChunk]]: + ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: """Generate text from a prompt. 
Args: @@ -1800,7 +1805,7 @@ class Llama: grammar=grammar, ) if stream: - chunks: Iterator[CompletionChunk] = completion_or_chunks + chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks return chunks completion: Completion = next(completion_or_chunks) # type: ignore return completion @@ -1828,7 +1833,7 @@ class Llama: stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, - ) -> Union[Completion, Iterator[CompletionChunk]]: + ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: """Generate text from a prompt. Args: @@ -1879,7 +1884,9 @@ class Llama: self, messages: List[ChatCompletionRequestMessage], functions: Optional[List[ChatCompletionFunction]] = None, - function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, + function_call: Optional[ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[ChatCompletionTool]] = None, + tool_choice: Optional[ChatCompletionToolChoiceOption] = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, @@ -1896,7 +1903,9 @@ class Llama: model: Optional[str] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, - ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: + ) -> Union[ + CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse] + ]: """Generate a chat completion from a list of messages. Args: @@ -1912,12 +1921,16 @@ class Llama: Returns: Generated chat completion or a stream of chat completion chunks. """ - handler = llama_chat_format.get_chat_completion_handler(self.chat_format) + handler = self.chat_handler or llama_chat_format.get_chat_completion_handler( + self.chat_format + ) return handler( - self, + llama=self, messages=messages, functions=functions, function_call=function_call, + tools=tools, + tool_choice=tool_choice, temperature=temperature, top_p=top_p, top_k=top_k, @@ -1974,6 +1987,7 @@ class Llama: numa=self.numa, # Chat Format Params chat_format=self.chat_format, + chat_handler=self.chat_handler, # Misc verbose=self.verbose, ) @@ -2015,6 +2029,7 @@ class Llama: numa=state["numa"], # Chat Format Params chat_format=state["chat_format"], + chat_handler=state["chat_handler"], # Misc verbose=state["verbose"], ) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 903a8c9..60b38d8 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,22 +1,24 @@ from __future__ import annotations import os +import ctypes import dataclasses from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol -from . import llama_types -from . 
import llama +import llama_cpp.llama_types as llama_types +import llama_cpp.llama as llama class LlamaChatCompletionHandler(Protocol): def __call__( self, + *, llama: llama.Llama, messages: List[llama_types.ChatCompletionRequestMessage], functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[ - Union[str, llama_types.ChatCompletionFunctionCall] - ] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, @@ -33,7 +35,8 @@ class LlamaChatCompletionHandler(Protocol): model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, - ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: + **kwargs, # type: ignore + ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]: ... @@ -199,7 +202,7 @@ def _convert_text_completion_to_chat( def _convert_text_completion_chunks_to_chat( - chunks: Iterator[llama_types.CompletionChunk], + chunks: Iterator[llama_types.CreateCompletionStreamResponse], ) -> Iterator[llama_types.ChatCompletionChunk]: for i, chunk in enumerate(chunks): if i == 0: @@ -239,12 +242,15 @@ def _convert_text_completion_chunks_to_chat( def _convert_completion_to_chat( completion_or_chunks: Union[ - llama_types.Completion, Iterator[llama_types.CompletionChunk] + llama_types.CreateCompletionResponse, + Iterator[llama_types.CreateCompletionStreamResponse], ], stream: bool = False, -) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: +) -> Union[ + llama_types.CreateChatCompletionResponse, Iterator[llama_types.ChatCompletionChunk] +]: if stream: - chunks: Iterator[llama_types.CompletionChunk] = completion_or_chunks # type: ignore + chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks # type: ignore return _convert_text_completion_chunks_to_chat(chunks) else: completion: llama_types.Completion = completion_or_chunks # type: ignore @@ -329,7 +335,9 @@ def get_chat_format(name: str): ) -def hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path: Union[str, os.PathLike[str]]) -> ChatFormatter: +def hf_autotokenizer_to_chat_formatter( + pretrained_model_name_or_path: Union[str, os.PathLike[str]] +) -> ChatFormatter: # https://huggingface.co/docs/transformers/main/chat_templating # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json @@ -538,7 +546,7 @@ def functionary_chat_handler( llama: llama.Llama, messages: List[llama_types.ChatCompletionRequestMessage], functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[Union[str, llama_types.ChatCompletionFunctionCall]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, @@ -555,6 +563,7 @@ def functionary_chat_handler( model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, + **kwargs, # type: ignore ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: SYSTEM_MESSAGE = """A chat between a 
curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" @@ -613,13 +622,13 @@ def functionary_chat_handler( all_messages: List[llama_types.ChatCompletionRequestMessage] = [] if functions is not None: all_messages.append( - llama_types.ChatCompletionRequestMessage( + llama_types.ChatCompletionRequestSystemMessage( role="system", content=generate_schema_from_functions(functions) ) ) all_messages.append( - llama_types.ChatCompletionRequestMessage( + llama_types.ChatCompletionRequestSystemMessage( role="system", content=SYSTEM_MESSAGE ) ) @@ -636,7 +645,9 @@ def functionary_chat_handler( all_messages.append(message) all_messages.append( - llama_types.ChatCompletionRequestMessage(role="assistant", content=None) + llama_types.ChatCompletionRequestAssistantMessage( + role="assistant", content=None + ) ) def message_to_str(msg: llama_types.ChatCompletionRequestMessage): @@ -713,6 +724,10 @@ def functionary_chat_handler( prompt=new_prompt, stop=["user:", ""], stream=False ) # type: ignore + assert "usage" in completion + assert isinstance(function_call, str) + assert stream is False # TODO: support stream mode + return llama_types.CreateChatCompletionResponse( id="chat" + completion["id"], object="chat.completion", @@ -734,3 +749,119 @@ def functionary_chat_handler( ], usage=completion["usage"], ) + + +class Llava15ChatHandler: + def __init__(self, clip_model_path: str): + import llama_cpp.llava_cpp as llava_cpp + + self._llava_cpp = llava_cpp + self.clip_model_path = clip_model_path + + self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0) + + def __del__(self): + if self.clip_ctx is not None: + self._llava_cpp.clip_free(self.clip_ctx) + self.clip_ctx = None + + def load_image(self, image_url: str) -> bytes: + if image_url.startswith("data:"): + import base64 + + image_bytes = base64.b64decode(image_url.split(",")[1]) + return image_bytes + else: + import urllib.request + + with urllib.request.urlopen(image_url) as f: + image_bytes = f.read() + return image_bytes + + def __call__( + self, + *, + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, + **kwargs, # type: ignore + ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]: + assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava + assert self.clip_ctx is not None + system_prompt = _get_system_message(messages) + system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions." + system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." + user_role = "\nUSER:" + assistant_role = "\nASSISTANT:" + llama.reset() + llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True)) + for message in messages: + if message["role"] == "user" and message["content"] is not None: + if isinstance(message["content"], str): + llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False)) + else: + assert isinstance(message["content"], list) + llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)) + for content in message["content"]: + if content["type"] == "text": + llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False)) + if content["type"] == "image_url": + image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"]) + import array + data_array = array.array('B', image_bytes) + c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array) + embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes)) + # image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes) + # embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes)) + try: + n_past = ctypes.c_int(llama.n_tokens) + n_past_p = ctypes.pointer(n_past) + self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p) + assert llama.n_ctx() >= n_past.value + llama.n_tokens = n_past.value + finally: + self._llava_cpp.llava_image_embed_free(embed) + if message["role"] == "assistant" and message["content"] is not None: + llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False)) + llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False)) + + prompt = llama._input_ids.tolist() + + return _convert_completion_to_chat(llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + ), stream=stream) \ No newline at end of file diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 29431d9..ccbea57 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -19,7 +19,7 @@ from typing import ( overload, ) -from . import llama_cpp +import llama_cpp.llama_cpp as llama_cpp # Type aliases llama_grammar_element = llama_cpp.llama_grammar_element diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index a64033e..69d07fc 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -1,4 +1,6 @@ -"""Types and request signatrues for OpenAI compatibility +"""Types and request signatures for OpenAI compatibility + +NOTE: These types may change to match the OpenAI OpenAPI specification. 
Based on the OpenAI OpenAPI specification: https://github.com/openai/openai-openapi/blob/master/openapi.yaml @@ -8,6 +10,12 @@ from typing import Any, List, Optional, Dict, Union from typing_extensions import TypedDict, NotRequired, Literal +# NOTE: Defining this correctly using annotations seems to break pydantic validation. +# This is a workaround until we can figure out how to do this correctly +# JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]] +JsonType = Union[None, int, str, bool, List[Any], Dict[str, Any]] + + class EmbeddingUsage(TypedDict): prompt_tokens: int total_tokens: int @@ -19,9 +27,6 @@ class Embedding(TypedDict): embedding: List[float] -EmbeddingData = Embedding - - class CreateEmbeddingResponse(TypedDict): object: Literal["list"] model: str @@ -49,110 +54,92 @@ class CompletionUsage(TypedDict): total_tokens: int -class CreateCompletionStreamResponse(TypedDict): - id: str - object: Literal["text_completion"] - created: int - model: str - choices: List[CompletionChoice] - - -CompletionChunk = CreateCompletionStreamResponse - - class CreateCompletionResponse(TypedDict): id: str object: Literal["text_completion"] created: int model: str choices: List[CompletionChoice] - usage: CompletionUsage + usage: NotRequired[CompletionUsage] -Completion = CreateCompletionResponse - - -class ChatCompletionFunctionCall(TypedDict): +class ChatCompletionResponseFunctionCall(TypedDict): name: str arguments: str class ChatCompletionResponseMessage(TypedDict): - role: Literal["assistant", "user", "system", "function"] content: Optional[str] - user: NotRequired[str] - function_call: NotRequired[ChatCompletionFunctionCall] + tool_calls: NotRequired["ChatCompletionMessageToolCalls"] + role: Literal["assistant", "function"] # NOTE: "function" may be incorrect here + function_call: NotRequired[ChatCompletionResponseFunctionCall] # DEPRECATED -ChatCompletionMessage = ChatCompletionResponseMessage - - -class ChatCompletionResponseFunction(TypedDict): +class ChatCompletionFunction(TypedDict): name: str description: NotRequired[str] - parameters: Dict[str, Any] # TODO: make this more specific - - -ChatCompletionFunction = ChatCompletionResponseFunction + parameters: Dict[str, JsonType] # TODO: make this more specific class ChatCompletionResponseChoice(TypedDict): index: int - message: ChatCompletionMessage + message: "ChatCompletionResponseMessage" finish_reason: Optional[str] -ChatCompletionChoice = ChatCompletionResponseChoice - - class CreateChatCompletionResponse(TypedDict): id: str object: Literal["chat.completion"] created: int model: str - choices: List[ChatCompletionChoice] + choices: List["ChatCompletionResponseChoice"] usage: CompletionUsage -ChatCompletion = CreateChatCompletionResponse +class ChatCompletionMessageToolCallChunkFunction(TypedDict): + name: str + arguments: str + + +class ChatCompletionMessageToolCallChunk(TypedDict): + index: int + id: NotRequired[str] + type: Literal["function"] + function: ChatCompletionMessageToolCallChunkFunction class ChatCompletionStreamResponseDeltaEmpty(TypedDict): pass -ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty +class ChatCompletionStreamResponseDeltaFunctionCall(TypedDict): + name: str + arguments: str class ChatCompletionStreamResponseDelta(TypedDict): - role: NotRequired[Literal["assistant"]] content: NotRequired[str] - function_call: NotRequired[ChatCompletionFunctionCall] - - -ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta + function_call: NotRequired[ + 
ChatCompletionStreamResponseDeltaFunctionCall + ] # DEPRECATED + tool_calls: NotRequired[List[ChatCompletionMessageToolCallChunk]] + role: NotRequired[Literal["system", "user", "assistant", "tool"]] class ChatCompletionStreamResponseChoice(TypedDict): index: int - delta: Union[ChatCompletionChunkDelta, ChatCompletionChunkDeltaEmpty] + delta: Union[ + ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty + ] finish_reason: Optional[Literal["stop", "length", "function_call"]] -ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice - - -class ChatCompletionStreamResponse(TypedDict): +class CreateChatCompletionStreamResponse(TypedDict): id: str model: str object: Literal["chat.completion.chunk"] created: int - choices: List[ChatCompletionChunkChoice] - - -ChatCompletionChunk = ChatCompletionStreamResponse - -JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]] + choices: List[ChatCompletionStreamResponseChoice] class ChatCompletionFunctions(TypedDict): @@ -165,8 +152,137 @@ class ChatCompletionFunctionCallOption(TypedDict): name: str -class ChatCompletionRequestMessage(TypedDict): - role: Literal["assistant", "user", "system", "function"] +class ChatCompletionRequestMessageContentPartText(TypedDict): + type: Literal["text"] + text: str + + +class ChatCompletionRequestMessageContentPartImageImageUrl(TypedDict): + url: str + detail: NotRequired[Literal["auto", "low", "high"]] + + +class ChatCompletionRequestMessageContentPartImage(TypedDict): + type: Literal["image_url"] + image_url: Union[str, ChatCompletionRequestMessageContentPartImageImageUrl] + + +ChatCompletionRequestMessageContentPart = Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, +] + + +class ChatCompletionRequestSystemMessage(TypedDict): + role: Literal["system"] content: Optional[str] - name: NotRequired[str] - function_call: NotRequired[ChatCompletionFunctionCall] + + +class ChatCompletionRequestUserMessage(TypedDict): + role: Literal["user"] + content: Optional[Union[str, List[ChatCompletionRequestMessageContentPart]]] + + +class ChatCompletionMessageToolCallFunction(TypedDict): + name: str + arguments: str + + +class ChatCompletionMessageToolCall(TypedDict): + id: str + type: Literal["function"] + function: ChatCompletionMessageToolCallFunction + + +ChatCompletionMessageToolCalls = List[ChatCompletionMessageToolCall] + + +class ChatCompletionRequestAssistantMessageFunctionCall(TypedDict): + name: str + arguments: str + + +class ChatCompletionRequestAssistantMessage(TypedDict): + role: Literal["assistant"] + content: Optional[str] + tool_calls: NotRequired[ChatCompletionMessageToolCalls] + function_call: NotRequired[ + ChatCompletionRequestAssistantMessageFunctionCall + ] # DEPRECATED + + +class ChatCompletionRequestToolMessage(TypedDict): + role: Literal["tool"] + content: Optional[str] + tool_call_id: str + + +class ChatCompletionRequestFunctionMessage(TypedDict): + role: Literal["function"] + content: Optional[str] + name: str + + +ChatCompletionRequestMessage = Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, +] + + +class ChatCompletionRequestFunctionCallOption(TypedDict): + name: str + + +ChatCompletionRequestFunctionCall = Union[ + Literal["none", "auto"], ChatCompletionRequestFunctionCallOption +] + +ChatCompletionFunctionParameters 
= Dict[str, JsonType] # TODO: make this more specific + + +class ChatCompletionToolFunction(TypedDict): + name: str + description: NotRequired[str] + parameters: ChatCompletionFunctionParameters + + +class ChatCompletionTool(TypedDict): + type: Literal["function"] + function: ChatCompletionToolFunction + + +class ChatCompletionNamedToolChoiceFunction(TypedDict): + name: str + + +class ChatCompletionNamedToolChoice(TypedDict): + type: Literal["function"] + function: ChatCompletionNamedToolChoiceFunction + + +ChatCompletionToolChoiceOption = Union[ + Literal["none", "auto"], ChatCompletionNamedToolChoice +] + + +# NOTE: The following type names are not part of the OpenAI OpenAPI specification +# and will be removed in a future major release. + +EmbeddingData = Embedding +CompletionChunk = CreateCompletionResponse +Completion = CreateCompletionResponse +CreateCompletionStreamResponse = CreateCompletionResponse +ChatCompletionMessage = ChatCompletionResponseMessage +ChatCompletionChoice = ChatCompletionResponseChoice +ChatCompletion = CreateChatCompletionResponse +ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty +ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice +ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta +ChatCompletionChunk = CreateChatCompletionStreamResponse +ChatCompletionStreamResponse = CreateChatCompletionStreamResponse +ChatCompletionResponseFunction = ChatCompletionFunction +ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py new file mode 100644 index 0000000..72f6a12 --- /dev/null +++ b/llama_cpp/llava_cpp.py @@ -0,0 +1,232 @@ +import sys +import os +import ctypes +from ctypes import ( + c_bool, + c_char_p, + c_int, + c_int8, + c_int32, + c_uint8, + c_uint32, + c_size_t, + c_float, + c_double, + c_void_p, + POINTER, + _Pointer, # type: ignore + Structure, + Array, +) +import pathlib +from typing import List, Union + +import llama_cpp.llama_cpp as llama_cpp + +# Load the library +def _load_shared_library(lib_base_name: str): + # Construct the paths to the possible shared library names + _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) + # Searching for the library in the current directory under the name "libllama" (default name + # for llamacpp) and "llama" (default name for this repo) + _lib_paths: List[pathlib.Path] = [] + # Determine the file extension based on the platform + if sys.platform.startswith("linux"): + _lib_paths += [ + _base_path / f"lib{lib_base_name}.so", + ] + elif sys.platform == "darwin": + _lib_paths += [ + _base_path / f"lib{lib_base_name}.so", + _base_path / f"lib{lib_base_name}.dylib", + ] + elif sys.platform == "win32": + _lib_paths += [ + _base_path / f"{lib_base_name}.dll", + _base_path / f"lib{lib_base_name}.dll", + ] + else: + raise RuntimeError("Unsupported platform") + + if "LLAMA_CPP_LIB" in os.environ: + lib_base_name = os.environ["LLAMA_CPP_LIB"] + _lib = pathlib.Path(lib_base_name) + _base_path = _lib.parent.resolve() + _lib_paths = [_lib.resolve()] + + cdll_args = dict() # type: ignore + # Add the library directory to the DLL search path on Windows (if needed) + if sys.platform == "win32" and sys.version_info >= (3, 8): + os.add_dll_directory(str(_base_path)) + if "CUDA_PATH" in os.environ: + os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin")) + os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib")) + cdll_args["winmode"] = ctypes.RTLD_GLOBAL + + # Try to load the shared 
library, handling potential errors + for _lib_path in _lib_paths: + if _lib_path.exists(): + try: + return ctypes.CDLL(str(_lib_path), **cdll_args) + except Exception as e: + raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") + + raise FileNotFoundError( + f"Shared library with base name '{lib_base_name}' not found" + ) + + +# Specify the base name of the shared library to load +_libllava_base_name = "llava" + +# Load the library +_libllava = _load_shared_library(_libllava_base_name) + + +################################################ +# llava.h +################################################ + +# struct clip_ctx; +clip_ctx_p = c_void_p + +# struct llava_image_embed { +# float * embed; +# int n_image_pos; +# }; +class llava_image_embed(Structure): + _fields_ = [ + ("embed", POINTER(c_float)), + ("n_image_pos", c_int), + ] + +# /** sanity check for clip <-> llava embed size match */ +# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip); +def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p) -> bool: + return _libllava.llava_validate_embed_size(ctx_llama, ctx_clip) + +_libllava.llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p, clip_ctx_p] +_libllava.llava_validate_embed_size.restype = c_bool + +# /** build an image embed from image file bytes */ +# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); +def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int]) -> "_Pointer[llava_image_embed]": + return _libllava.llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length) + +_libllava.llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p, c_int, POINTER(c_uint8), c_int] +_libllava.llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed) + +# /** build an image embed from a path to an image filename */ +# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); +def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes) -> "_Pointer[llava_image_embed]": + return _libllava.llava_image_embed_make_with_filename(ctx_clip, n_threads, image_path) + +_libllava.llava_image_embed_make_with_filename.argtypes = [clip_ctx_p, c_int, c_char_p] +_libllava.llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed) + +# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); +# /** free an embedding made with llava_image_embed_make_* */ +def llava_image_embed_free(embed: "_Pointer[llava_image_embed]"): + return _libllava.llava_image_embed_free(embed) + +_libllava.llava_image_embed_free.argtypes = [POINTER(llava_image_embed)] +_libllava.llava_image_embed_free.restype = None + +# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. 
*/ +# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); +def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]") -> bool: + return _libllava.llava_eval_image_embed(ctx_llama, embed, n_batch, n_past) + +_libllava.llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p, POINTER(llava_image_embed), c_int, POINTER(c_int)] +_libllava.llava_eval_image_embed.restype = c_bool + + +################################################ +# clip.h +################################################ + + +# struct clip_vision_hparams { +# int32_t image_size; +# int32_t patch_size; +# int32_t hidden_size; +# int32_t n_intermediate; +# int32_t projection_dim; +# int32_t n_head; +# int32_t n_layer; +# float eps; +# }; +class clip_vision_hparams(Structure): + _fields_ = [ + ("image_size", c_int32), + ("patch_size", c_int32), + ("hidden_size", c_int32), + ("n_intermediate", c_int32), + ("projection_dim", c_int32), + ("n_head", c_int32), + ("n_layer", c_int32), + ("eps", c_float), + ] + +# /** load mmproj model */ +# CLIP_API struct clip_ctx * clip_model_load(const char * fname, const int verbosity); +def clip_model_load(fname: bytes, verbosity: Union[c_int, int]) -> clip_ctx_p: + return _libllava.clip_model_load(fname, verbosity) + +_libllava.clip_model_load.argtypes = [c_char_p, c_int] +_libllava.clip_model_load.restype = clip_ctx_p + +# /** free mmproj model */ +# CLIP_API void clip_free(struct clip_ctx * ctx); +def clip_free(ctx: clip_ctx_p): + return _libllava.clip_free(ctx) + +_libllava.clip_free.argtypes = [clip_ctx_p] +_libllava.clip_free.restype = None + +# size_t clip_embd_nbytes(const struct clip_ctx * ctx); +# int clip_n_patches(const struct clip_ctx * ctx); +# int clip_n_mmproj_embd(const struct clip_ctx * ctx); + +# // RGB uint8 image +# struct clip_image_u8 { +# int nx; +# int ny; +# uint8_t * data = NULL; +# size_t size; +# }; + +# // RGB float32 image (NHWC) +# // Memory layout: RGBRGBRGB... 
+# struct clip_image_f32 { +# int nx; +# int ny; +# float * data = NULL; +# size_t size; +# }; + +# struct clip_image_u8_batch { +# struct clip_image_u8 * data; +# size_t size; +# }; + +# struct clip_image_f32_batch { +# struct clip_image_f32 * data; +# size_t size; +# }; + +# struct clip_image_u8 * make_clip_image_u8(); +# struct clip_image_f32 * make_clip_image_f32(); +# CLIP_API void clip_image_u8_free(clip_image_u8 * img); +# CLIP_API void clip_image_f32_free(clip_image_f32 * img); +# CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); +# /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ +# CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); + +# bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square); +# bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec); + +# bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs, +# float * vec); + +# bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype); \ No newline at end of file diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 93afc3e..8ebc427 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -138,6 +138,10 @@ class Settings(BaseSettings): default="llama-2", description="Chat format to use.", ) + clip_model_path: Optional[str] = Field( + default=None, + description="Path to a CLIP model to use for multi-modal chat completion.", + ) # Cache Params cache: bool = Field( default=False, @@ -375,6 +379,14 @@ def create_app(settings: Optional[Settings] = None): ) app.include_router(router) global llama + + ## + chat_handler = None + if settings.chat_format == "llava-1-5": + assert settings.clip_model_path is not None + chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(clip_model_path=settings.clip_model_path) + ## + llama = llama_cpp.Llama( model_path=settings.model, # Model Params @@ -411,6 +423,7 @@ def create_app(settings: Optional[Settings] = None): numa=settings.numa, # Chat Format Params chat_format=settings.chat_format, + chat_handler=chat_handler, # Misc verbose=settings.verbose, ) @@ -580,10 +593,6 @@ class CreateCompletionRequest(BaseModel): max_tokens: int = max_tokens_field temperature: float = temperature_field top_p: float = top_p_field - mirostat_mode: int = mirostat_mode_field - mirostat_tau: float = mirostat_tau_field - mirostat_eta: float = mirostat_eta_field - grammar: Optional[str] = None echo: bool = Field( default=False, description="Whether to echo the prompt in the generated text. 
Useful for chatbots.", @@ -610,6 +619,10 @@ class CreateCompletionRequest(BaseModel): top_k: int = top_k_field repeat_penalty: float = repeat_penalty_field logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) + mirostat_mode: int = mirostat_mode_field + mirostat_tau: float = mirostat_tau_field + mirostat_eta: float = mirostat_eta_field + grammar: Optional[str] = None model_config = { "json_schema_extra": { @@ -688,7 +701,7 @@ async def create_completion( kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar) iterator_or_completion: Union[ - llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk] + llama_cpp.CreateCompletionResponse, Iterator[llama_cpp.CreateCompletionStreamResponse] ] = await run_in_threadpool(llama, **kwargs) if isinstance(iterator_or_completion, Iterator): @@ -697,7 +710,7 @@ async def create_completion( # If no exception was raised from first_response, we can assume that # the iterator is valid and we can use it to stream the response. - def iterator() -> Iterator[llama_cpp.CompletionChunk]: + def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]: yield first_response yield from iterator_or_completion @@ -748,27 +761,30 @@ class ChatCompletionRequestMessage(BaseModel): ) content: Optional[str] = Field(default="", description="The content of the message.") -from typing import Any class CreateChatCompletionRequest(BaseModel): - messages: List[Any] = Field( + messages: List[llama_cpp.ChatCompletionRequestMessage] = Field( default=[], description="A list of messages to generate completions for." ) functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field( default=None, description="A list of functions to apply to the generated completions.", ) - function_call: Optional[Union[Literal["auto", "none"], llama_cpp.ChatCompletionFunctionCallOption]] = Field( + function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field( default=None, description="A function to apply to the generated completions.", ) + tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field( + default=None, + description="A list of tools to apply to the generated completions.", + ) + tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field( + default=None, + description="A tool to apply to the generated completions.", + ) # TODO: verify max_tokens: int = max_tokens_field temperature: float = temperature_field top_p: float = top_p_field - mirostat_mode: int = mirostat_mode_field - mirostat_tau: float = mirostat_tau_field - mirostat_eta: float = mirostat_eta_field - grammar: Optional[str] = None stop: Optional[List[str]] = stop_field stream: bool = stream_field presence_penalty: Optional[float] = presence_penalty_field @@ -784,6 +800,10 @@ class CreateChatCompletionRequest(BaseModel): top_k: int = top_k_field repeat_penalty: float = repeat_penalty_field logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) + mirostat_mode: int = mirostat_mode_field + mirostat_tau: float = mirostat_tau_field + mirostat_eta: float = mirostat_eta_field + grammar: Optional[str] = None model_config = { "json_schema_extra": { diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 2833a6f..381efbf 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 2833a6f63c1b87c7f4ac574bcf7a15a2f3bf3ede +Subproject commit 381efbf480959bb6d1e247a8b0c2328f22e350f8
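For reference, the new ctypes bindings in `llama_cpp/llava_cpp.py` can also be driven directly. The sketch below mirrors what `Llava15ChatHandler.__call__` does internally when it reaches an `image_url` content part; the model, mmproj, and image paths are placeholders, and in the handler this step runs only after the system prompt and preceding user text have been evaluated:

```python
import array
import ctypes

from llama_cpp import Llama
import llama_cpp.llava_cpp as llava_cpp

# Placeholder paths for the language model and the CLIP/mmproj model.
llm = Llama(model_path="path/to/llava-v1.5-7b.Q4_K.gguf", logits_all=True, n_ctx=2048)
clip_ctx = llava_cpp.clip_model_load("path/to/mmproj-model-f16.gguf".encode(), 0)

with open("image.png", "rb") as f:  # placeholder image file
    image_bytes = f.read()

# The binding expects an unsigned char pointer plus an explicit byte length.
data = array.array("B", image_bytes)
c_ubyte_ptr = (ctypes.c_ubyte * len(data)).from_buffer(data)
embed = llava_cpp.llava_image_embed_make_with_bytes(
    ctx_clip=clip_ctx,
    n_threads=llm.context_params.n_threads,
    image_bytes=c_ubyte_ptr,
    image_bytes_length=len(image_bytes),
)
try:
    # Write the image embedding into the llama context, advancing n_past past it.
    n_past = ctypes.c_int(llm.n_tokens)
    llava_cpp.llava_eval_image_embed(
        ctx_llama=llm.ctx,
        embed=embed,
        n_batch=llm.n_batch,
        n_past=ctypes.pointer(n_past),
    )
    llm.n_tokens = n_past.value
finally:
    llava_cpp.llava_image_embed_free(embed)

llava_cpp.clip_free(clip_ctx)
```

The handler then finishes the turn by calling `llama.create_completion` on the accumulated tokens and converting the result with `_convert_completion_to_chat`, as shown in `llama_chat_format.py` above.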