From eb7645b3ba84e182a903663d68c0b4864b670f9b Mon Sep 17 00:00:00 2001 From: Tanner Hobson Date: Fri, 9 Jun 2023 13:13:08 -0400 Subject: [PATCH] Add support for logit_bias and logit_bias_type parameters --- llama_cpp/llama.py | 2 ++ llama_cpp/server/app.py | 53 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 02fe774..197511c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1380,6 +1380,7 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, + logits_processor: Optional[LogitsProcessorList] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: """Generate a chat completion from a list of messages. @@ -1421,6 +1422,7 @@ class Llama: mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, model=model, + logits_processor=logits_processor, ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index f70d8f0..a6194f5 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -249,13 +249,14 @@ class CreateCompletionRequest(BaseModel): ) presence_penalty: Optional[float] = presence_penalty_field frequency_penalty: Optional[float] = frequency_penalty_field + logit_bias: Optional[Dict[str, float]] = Field(None) + logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) # ignored or currently unsupported model: Optional[str] = model_field n: Optional[int] = 1 logprobs: Optional[int] = Field(None) best_of: Optional[int] = 1 - logit_bias: Optional[Dict[str, float]] = Field(None) user: Optional[str] = Field(None) # llama.cpp specific parameters @@ -274,6 +275,39 @@ class CreateCompletionRequest(BaseModel): CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion) +def make_logit_bias_processor( + llama: llama_cpp.Llama, + logit_bias: Dict[str, float], + logit_bias_type: Optional[Literal["input_ids", "tokens"]], +): + if logit_bias_type is None: + logit_bias_type = "input_ids" + + to_bias: Dict[int, float] = {} + if logit_bias_type == "input_ids": + for input_id, score in logit_bias.items(): + input_id = int(input_id) + to_bias[input_id] = score + + elif logit_bias_type == "tokens": + for token, score in logit_bias.items(): + token = token.encode('utf-8') + for input_id in llama.tokenize(token, add_bos=False): + to_bias[input_id] = score + + def logit_bias_processor( + input_ids: List[int], + scores: List[float], + ) -> List[float]: + new_scores = [None] * len(scores) + for input_id, score in enumerate(scores): + new_scores[input_id] = score + to_bias.get(input_id, 0.0) + + return new_scores + + return logit_bias_processor + + @router.post( "/v1/completions", response_model=CreateCompletionResponse, @@ -291,9 +325,16 @@ async def create_completion( "n", "best_of", "logit_bias", + "logit_bias_type", "user", } kwargs = body.dict(exclude=exclude) + + if body.logit_bias is not None: + kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([ + make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type), + ]) + if body.stream: send_chan, recv_chan = anyio.create_memory_object_stream(10) @@ -372,11 +413,12 @@ class CreateChatCompletionRequest(BaseModel): stream: bool = stream_field presence_penalty: Optional[float] = presence_penalty_field frequency_penalty: Optional[float] = frequency_penalty_field + logit_bias: Optional[Dict[str, float]] = Field(None) + logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) # ignored or currently unsupported model: Optional[str] = model_field n: Optional[int] = 1 - logit_bias: Optional[Dict[str, float]] = Field(None) user: Optional[str] = Field(None) # llama.cpp specific parameters @@ -413,9 +455,16 @@ async def create_chat_completion( exclude = { "n", "logit_bias", + "logit_bias_type", "user", } kwargs = body.dict(exclude=exclude) + + if body.logit_bias is not None: + kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([ + make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type), + ]) + if body.stream: send_chan, recv_chan = anyio.create_memory_object_stream(10)