diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index bf124f9..cd16bca 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -128,6 +128,20 @@ class Llama:
             repeat_penalty=repeat_penalty,
         )
 
+    def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty):
+        last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
+        last_n_tokens.extend(past_tokens)
+        for i in range(max_tokens):
+            token = self._sample(
+                last_n_tokens,
+                top_p=top_p,
+                top_k=top_k,
+                temp=temp,
+                repeat_penalty=repeat_penalty
+            )
+            yield token
+            self._eval([token], len(past_tokens) + i)
+
     def __call__(
         self,
         prompt: str,
@@ -162,8 +176,9 @@ class Llama:
         Returns:
             Response object containing the generated text.
         """
+        completion_id = f"cmpl-{str(uuid.uuid4())}"
+        created = int(time.time())
         text = b""
-        finish_reason = "length"
         completion_tokens = []
         last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
 
@@ -182,14 +197,8 @@ class Llama:
         if stop is not None:
             stop = [s.encode("utf-8") for s in stop]
 
-        for i in range(max_tokens):
-            token = self._sample(
-                last_n_tokens,
-                top_p=top_p,
-                top_k=top_k,
-                temp=temperature,
-                repeat_penalty=repeat_penalty
-            )
+        finish_reason = None
+        for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty):
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
@@ -204,7 +213,8 @@ class Llama:
                 finish_reason = "stop"
                 break
 
-            self._eval([token], len(prompt_tokens) + len(completion_tokens))
+        if finish_reason is None:
+            finish_reason = "length"
 
         text = text.decode("utf-8")
 
@@ -220,9 +230,9 @@ class Llama:
             )[:logprobs]
 
         return {
-            "id": f"cmpl-{str(uuid.uuid4())}",  # Likely to change
+            "id": completion_id,
             "object": "text_completion",
-            "created": int(time.time()),
+            "created": created,
            "model": self.model_path,
             "choices": [
                 {
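
For context, a minimal usage sketch of the public API this change refactors. It is not part of the diff: the model path is a placeholder, and the fields read from "choices" are assumed to follow the OpenAI-style response dict assembled at the end of __call__ (only "id", "object", "created", "model", and "choices" are visible in the hunks). Externally, __call__ should behave as before; the sampling/eval loop has simply moved into the new _generate generator.

# Sketch only (not part of the diff): exercising __call__, which now delegates
# its token loop to the new _generate helper. The model path is a placeholder.
from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")

completion = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=32,
    temperature=0.8,
    top_p=0.95,
    top_k=40,
    repeat_penalty=1.1,
    stop=["Q:", "\n"],
)

# "id" and "created" now come from the completion_id/created values set at the
# top of __call__, so they are stable across the whole response.
print(completion["id"], completion["created"])
print(completion["choices"][0]["text"])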