Extract generate method

Andrei Betlen 2023-03-28 02:42:22 -04:00
parent 1c823f6d0f
commit 30fc0f3866


@@ -128,6 +128,20 @@ class Llama:
             repeat_penalty=repeat_penalty,
         )
 
+    def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty):
+        last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
+        last_n_tokens.extend(past_tokens)
+        for i in range(max_tokens):
+            token = self._sample(
+                last_n_tokens,
+                top_p=top_p,
+                top_k=top_k,
+                temp=temp,
+                repeat_penalty=repeat_penalty
+            )
+            yield token
+            self._eval([token], len(past_tokens) + i)
+
     def __call__(
         self,
         prompt: str,
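The hunk above factors the sampling loop out of __call__ into a _generate generator that yields one token at a time and feeds each sampled token back through _eval. Below is a minimal, self-contained sketch of that pattern; the TinyModel class, its placeholder _sample/_eval bodies, and the token values are invented for illustration and are not part of this commit.

from collections import deque
from typing import Deque, Generator, List


class TinyModel:
    """Toy stand-in for the Llama class, used only to illustrate the generator pattern."""

    def __init__(self, last_n: int = 64):
        self.last_n = last_n

    def _sample(self, last_n_tokens: Deque[int]) -> int:
        # Placeholder sampler: a real model would pick the next token from its logits.
        return (last_n_tokens[-1] + 1) % 100

    def _eval(self, tokens: List[int], n_past: int) -> None:
        # Placeholder for feeding the new token back through the model.
        pass

    def _generate(self, past_tokens: List[int], max_tokens: int) -> Generator[int, None, None]:
        # Same shape as the extracted method: prime the recent-token window,
        # then sample, yield, and evaluate one token per iteration.
        last_n_tokens: Deque[int] = deque([0] * self.last_n, maxlen=self.last_n)
        last_n_tokens.extend(past_tokens)
        for i in range(max_tokens):
            token = self._sample(last_n_tokens)
            yield token  # the caller decides whether to keep consuming
            last_n_tokens.append(token)
            self._eval([token], len(past_tokens) + i)


if __name__ == "__main__":
    model = TinyModel()
    print(list(model._generate([1, 2, 3], max_tokens=5)))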
@@ -162,8 +176,9 @@ class Llama:
         Returns:
             Response object containing the generated text.
         """
+        completion_id = f"cmpl-{str(uuid.uuid4())}"
+        created = int(time.time())
         text = b""
-        finish_reason = "length"
         completion_tokens = []
         last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
@@ -182,14 +197,8 @@ class Llama:
         if stop is not None:
             stop = [s.encode("utf-8") for s in stop]
 
-        for i in range(max_tokens):
-            token = self._sample(
-                last_n_tokens,
-                top_p=top_p,
-                top_k=top_k,
-                temp=temperature,
-                repeat_penalty=repeat_penalty
-            )
+        finish_reason = None
+        for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty):
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
@@ -204,7 +213,8 @@ class Llama:
                 finish_reason = "stop"
                 break
-            self._eval([token], len(prompt_tokens) + len(completion_tokens))
+        if finish_reason is None:
+            finish_reason = "length"
         text = text.decode("utf-8")
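Continuing the TinyModel sketch above, the consumer side of the refactor might look like the following; the EOS sentinel value and the collect helper are made up for illustration.

EOS = -1  # stand-in for llama_cpp.llama_token_eos()

def collect(model: TinyModel, prompt_tokens, max_tokens):
    # Mirrors the new __call__ flow: consume the generator, stop on EOS,
    # and fall back to a "length" finish reason when the token budget runs out.
    completion_tokens = []
    finish_reason = None
    for token in model._generate(prompt_tokens, max_tokens):
        if token == EOS:
            finish_reason = "stop"
            break
        completion_tokens.append(token)
    if finish_reason is None:
        finish_reason = "length"
    return completion_tokens, finish_reason

print(collect(TinyModel(), [1, 2, 3], max_tokens=5))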
@@ -220,9 +230,9 @@
             )[:logprobs]
 
         return {
-            "id": f"cmpl-{str(uuid.uuid4())}", # Likely to change
+            "id": completion_id,
             "object": "text_completion",
-            "created": int(time.time()),
+            "created": created,
             "model": self.model_path,
             "choices": [
                 {
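For reference, a hypothetical call into the refactored completion API; the model path and prompt are placeholders, and the constructor options follow the library's general style rather than this exact revision.

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

output = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=32,
    stop=["Q:", "\n"],
    temperature=0.8,
)

# "id" and "created" are now computed once at the top of __call__
# and reused in the response dict.
print(output["id"], output["created"])
print(output["choices"][0]["text"])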