Update api to allow for easier interactive mode
This commit is contained in:
parent
eef627c09c
commit
a4a1bbeaa9
|
@ -63,6 +63,11 @@ class Llama:
|
||||||
self.params.embedding = embedding
|
self.params.embedding = embedding
|
||||||
|
|
||||||
self.last_n_tokens_size = last_n_tokens_size
|
self.last_n_tokens_size = last_n_tokens_size
|
||||||
|
self.last_n_tokens_data = deque(
|
||||||
|
[llama_cpp.llama_token(0)] * self.last_n_tokens_size,
|
||||||
|
maxlen=self.last_n_tokens_size,
|
||||||
|
)
|
||||||
|
self.tokens_consumed = 0
|
||||||
self.n_batch = n_batch
|
self.n_batch = n_batch
|
||||||
|
|
||||||
self.n_threads = n_threads or multiprocessing.cpu_count()
|
self.n_threads = n_threads or multiprocessing.cpu_count()
|
||||||
|
@ -115,6 +120,67 @@ class Llama:
|
||||||
output += llama_cpp.llama_token_to_str(self.ctx, token)
|
output += llama_cpp.llama_token_to_str(self.ctx, token)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset the model state."""
|
||||||
|
self.last_n_tokens_data.extend(
|
||||||
|
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
|
||||||
|
)
|
||||||
|
self.tokens_consumed = 0
|
||||||
|
|
||||||
|
def eval(self, tokens: Sequence[llama_cpp.llama_token]):
|
||||||
|
"""Evaluate a list of tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: The list of tokens to evaluate.
|
||||||
|
"""
|
||||||
|
assert self.ctx is not None
|
||||||
|
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
||||||
|
for i in range(0, len(tokens), self.n_batch):
|
||||||
|
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||||
|
n_past = min(n_ctx - len(batch), self.tokens_consumed)
|
||||||
|
return_code = llama_cpp.llama_eval(
|
||||||
|
ctx=self.ctx,
|
||||||
|
tokens=(llama_cpp.llama_token * len(batch))(*batch),
|
||||||
|
n_tokens=llama_cpp.c_int(len(batch)),
|
||||||
|
n_past=llama_cpp.c_int(n_past),
|
||||||
|
n_threads=llama_cpp.c_int(self.n_threads),
|
||||||
|
)
|
||||||
|
if int(return_code) != 0:
|
||||||
|
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||||
|
self.last_n_tokens_data.extend(batch)
|
||||||
|
self.tokens_consumed += len(batch)
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
top_k: int,
|
||||||
|
top_p: float,
|
||||||
|
temp: float,
|
||||||
|
repeat_penalty: float,
|
||||||
|
):
|
||||||
|
"""Sample a token from the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_k: The top-k sampling parameter.
|
||||||
|
top_p: The top-p sampling parameter.
|
||||||
|
temp: The temperature parameter.
|
||||||
|
repeat_penalty: The repeat penalty parameter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The sampled token.
|
||||||
|
"""
|
||||||
|
assert self.ctx is not None
|
||||||
|
return llama_cpp.llama_sample_top_p_top_k(
|
||||||
|
ctx=self.ctx,
|
||||||
|
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
|
||||||
|
*self.last_n_tokens_data
|
||||||
|
),
|
||||||
|
last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
|
||||||
|
top_k=llama_cpp.c_int(top_k),
|
||||||
|
top_p=llama_cpp.c_float(top_p),
|
||||||
|
temp=llama_cpp.c_float(temp),
|
||||||
|
repeat_penalty=llama_cpp.c_float(repeat_penalty),
|
||||||
|
)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
tokens: Sequence[llama_cpp.llama_token],
|
tokens: Sequence[llama_cpp.llama_token],
|
||||||
|
@ -125,7 +191,7 @@ class Llama:
|
||||||
) -> Generator[
|
) -> Generator[
|
||||||
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
|
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
|
||||||
]:
|
]:
|
||||||
"""Generate tokens.
|
"""Create a generator of tokens from a prompt.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> llama = Llama("models/ggml-7b.bin")
|
>>> llama = Llama("models/ggml-7b.bin")
|
||||||
|
@ -149,37 +215,14 @@ class Llama:
|
||||||
top_p = 0.0
|
top_p = 0.0
|
||||||
top_k = 1
|
top_k = 1
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
self.reset()
|
||||||
n_tokens = 0
|
|
||||||
last_n_tokens = deque(
|
|
||||||
[llama_cpp.llama_token(0)] * self.last_n_tokens_size,
|
|
||||||
maxlen=self.last_n_tokens_size,
|
|
||||||
)
|
|
||||||
while True:
|
while True:
|
||||||
for i in range(0, len(tokens), self.n_batch):
|
self.eval(tokens)
|
||||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
token = self.sample(
|
||||||
n_past = min(n_ctx - len(batch), n_tokens)
|
top_k=top_k,
|
||||||
return_code = llama_cpp.llama_eval(
|
top_p=top_p,
|
||||||
ctx=self.ctx,
|
temp=temp,
|
||||||
tokens=(llama_cpp.llama_token * len(batch))(*batch),
|
repeat_penalty=repeat_penalty,
|
||||||
n_tokens=llama_cpp.c_int(len(batch)),
|
|
||||||
n_past=llama_cpp.c_int(n_past),
|
|
||||||
n_threads=llama_cpp.c_int(self.n_threads),
|
|
||||||
)
|
|
||||||
if int(return_code) != 0:
|
|
||||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
|
||||||
last_n_tokens.extend(batch)
|
|
||||||
n_tokens += len(batch)
|
|
||||||
token = llama_cpp.llama_sample_top_p_top_k(
|
|
||||||
ctx=self.ctx,
|
|
||||||
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
|
|
||||||
*last_n_tokens
|
|
||||||
),
|
|
||||||
last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
|
|
||||||
top_k=llama_cpp.c_int(top_k),
|
|
||||||
top_p=llama_cpp.c_float(top_p),
|
|
||||||
temp=llama_cpp.c_float(temp),
|
|
||||||
repeat_penalty=llama_cpp.c_float(repeat_penalty),
|
|
||||||
)
|
)
|
||||||
tokens_or_none = yield token
|
tokens_or_none = yield token
|
||||||
tokens = [token]
|
tokens = [token]
|
||||||
|
@ -197,7 +240,8 @@ class Llama:
|
||||||
"""
|
"""
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
tokens = self.tokenize(input.encode("utf-8"))
|
tokens = self.tokenize(input.encode("utf-8"))
|
||||||
next(self.generate(tokens, top_k=0, top_p=0.0, temp=1.0, repeat_penalty=1.0))
|
self.reset()
|
||||||
|
self.eval(tokens)
|
||||||
n_tokens = len(tokens)
|
n_tokens = len(tokens)
|
||||||
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
|
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
|
||||||
: llama_cpp.llama_n_embd(self.ctx)
|
: llama_cpp.llama_n_embd(self.ctx)
|
||||||
|
|
Loading…
Reference in a new issue