Extract generate method

This commit is contained in:
Andrei Betlen 2023-03-28 02:42:22 -04:00
parent 1c823f6d0f
commit 30fc0f3866

View file

@ -128,6 +128,20 @@ class Llama:
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
) )
def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty):
last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
last_n_tokens.extend(past_tokens)
for i in range(max_tokens):
token = self._sample(
last_n_tokens,
top_p=top_p,
top_k=top_k,
temp=temp,
repeat_penalty=repeat_penalty
)
yield token
self._eval([token], len(past_tokens) + i)
def __call__( def __call__(
self, self,
prompt: str, prompt: str,
@ -162,8 +176,9 @@ class Llama:
Returns: Returns:
Response object containing the generated text. Response object containing the generated text.
""" """
completion_id = f"cmpl-{str(uuid.uuid4())}"
created= int(time.time())
text = b"" text = b""
finish_reason = "length"
completion_tokens = [] completion_tokens = []
last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n) last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
@ -182,14 +197,8 @@ class Llama:
if stop is not None: if stop is not None:
stop = [s.encode("utf-8") for s in stop] stop = [s.encode("utf-8") for s in stop]
for i in range(max_tokens): finish_reason = None
token = self._sample( for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty):
last_n_tokens,
top_p=top_p,
top_k=top_k,
temp=temperature,
repeat_penalty=repeat_penalty
)
if token == llama_cpp.llama_token_eos(): if token == llama_cpp.llama_token_eos():
finish_reason = "stop" finish_reason = "stop"
break break
@ -204,7 +213,8 @@ class Llama:
finish_reason = "stop" finish_reason = "stop"
break break
self._eval([token], len(prompt_tokens) + len(completion_tokens)) if finish_reason is None:
finish_reason = "length"
text = text.decode("utf-8") text = text.decode("utf-8")
@ -220,9 +230,9 @@ class Llama:
)[:logprobs] )[:logprobs]
return { return {
"id": f"cmpl-{str(uuid.uuid4())}", # Likely to change "id": completion_id,
"object": "text_completion", "object": "text_completion",
"created": int(time.time()), "created": created,
"model": self.model_path, "model": self.model_path,
"choices": [ "choices": [
{ {