llama.cpp/examples/low_level_api/low_level_api_chatllama_cpp.py

293 lines
9.1 KiB
Python
Raw Normal View History

2023-04-03 20:54:46 +00:00
"""
This is an example implementation of main.cpp from llama.cpp.
Quirks:
 * It's not exactly alike, since this port is designed around programmatic I/O.
 * Input is always echoed when input echo is on, so it should be turned off when using "input()".
 * The first antiprompt should be the user prompt, like "\nUser:",
   because it's added when n_predict is reached (i.e. generation ended prematurely).
 * n_predict can be set to -1 for unlimited-length responses (or just a really high value).
 * It's always in interactive mode; generation ends either by reaching an antiprompt
   or by running out of n_predict.
 * Instruction mode adds its own antiprompt.
You should also still feed the model a "primer" prompt that
shows it the expected format.
"""
import llama_cpp
# A LLaMA interactive session
class LLaMAInteract:
    """Interactive, generator-driven port of the llama.cpp ``main.cpp`` loop.

    Queue text with :meth:`input` and iterate :meth:`output` to receive the
    decoded text of echoed/sampled tokens. Use the instance as a context
    manager so the llama context is freed on exit.
    """

    def __init__(
        self,
        primer: str = "",
        model: str = "./models/30B/ggml-model-q4_0.bin",
        instruct: bool = False,
        n_ctx: int = 1024,
        seed: int = 0,
        n_threads: int = 8,
        antiprompt: "list[str] | None" = None,
        input_echo: bool = True,
        n_predict: int = 20,
        n_keep: int = 0,
        n_batch: int = 8,
        repeat_last_n: int = 64,
        top_k: int = 50,
        top_p: float = 1.,
        temp: float = 1.0,
        repeat_penalty: float = 1,
        init_break: bool = True,
        instruct_inp_prefix: str = "\n\n### Instruction:\n\n",
        instruct_inp_suffix: str = "\n\n### Response:\n\n",
    ) -> None:
        """Load the model and prepare the session.

        Args:
            primer: initial prompt, tokenized with BOS and queued as input.
            model: path passed to ``llama_init_from_file``.
            instruct: instruction mode. User input is wrapped in
                ``instruct_inp_prefix``/``instruct_inp_suffix``. The stripped
                prefix becomes an antiprompt, the whole primer is kept on
                context swaps, and EOS is not rewritten to a newline.
            n_ctx: context size requested from llama.cpp.
            seed: RNG seed set on the context params.
            n_threads: thread count passed to ``llama_eval``.
            antiprompt: strings; generation stops once the token history ends
                with any of them. The first one (or the instruction prefix in
                instruct mode) is the one injected on EOS / budget exhaustion.
            input_echo: whether queued input tokens are yielded by ``output``.
            n_predict: sampled-token budget per ``output`` call; -1 = unlimited.
            n_keep: primer tokens retained on context swap. Clamped to the
                primer length if negative, too large, or in instruct mode.
            n_batch: max input tokens forwarded per ``llama_eval`` call.
            repeat_last_n, top_k, top_p, temp, repeat_penalty: sampling
                parameters passed to ``llama_sample_top_p_top_k``.
            init_break: if True, the first ``generate`` pass ends as soon as
                all queued input has been forwarded.
            instruct_inp_prefix, instruct_inp_suffix: instruct-mode wrappers.

        Raises:
            RuntimeError: if the model cannot be loaded or tokenization fails.
        """
        # input args
        self.instruct = instruct
        self.n_threads = n_threads
        self.input_echo = input_echo
        self.n_predict = n_predict
        self.n_keep = n_keep
        self.n_batch = n_batch
        self.repeat_last_n = repeat_last_n
        self.top_k = top_k
        self.top_p = top_p
        self.temp = temp
        self.repeat_penalty = repeat_penalty
        self.init_break = init_break

        # runtime args
        self.input_consumed = 0      # index of the next embd_inp token to forward
        self.embd = []               # tokens queued for the next llama_eval call
        self.embd_inp = []           # all input tokens: primer, user input, injected antiprompts
        self.n_past = 0              # tokens already evaluated into the context
        self.first_antiprompt = []   # tokenized antiprompts
        self.remaining_tokens = self.n_predict
        self.output_echo = input_echo

        # model load
        self.lparams = llama_cpp.llama_context_default_params()
        self.lparams.n_ctx = n_ctx
        self.lparams.seed = seed
        self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)
        if not self.ctx:
            # llama_init_from_file returns NULL on failure; fail here instead of
            # crashing inside a later llama_* call.
            raise RuntimeError(f"Failed to load model from {model!r}")

        # determine the required inference memory per token:
        tmp = [0, 1, 2, 3]
        llama_cpp.llama_eval(
            self.ctx, (llama_cpp.llama_token * len(tmp))(*tmp), len(tmp), 0, self.n_threads
        )

        # determine newline token
        self.llama_token_newline = self._tokenize("\n", False)
        self.inp_prefix = self._tokenize(instruct_inp_prefix)
        self.inp_suffix = self._tokenize(instruct_inp_suffix, False)

        # add instruction as antiprompt
        if self.instruct:
            self.first_antiprompt.append(self._tokenize(instruct_inp_prefix.strip(), False))

        # primer feed
        if len(primer) > 0:
            self.embd_inp += self._tokenize(primer)

        # number of tokens to keep when resetting context
        if self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct:
            self.n_keep = len(self.embd_inp)

        # create internal context
        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
        self.last_n_tokens = [0] * self.n_ctx  # TODO: deque doesnt support slices

        # determine antiprompt tokens
        for text in (antiprompt or []):
            self.first_antiprompt.append(self._tokenize(text, False))

    # tokenize a prompt
    def _tokenize(self, prompt, bos=True):
        """Tokenize ``prompt`` (optionally prefixed with BOS) into a list of token ids.

        Raises:
            RuntimeError: if llama_tokenize still fails after resizing.
        """
        data = prompt.encode("utf8")
        # Size by encoded bytes (not characters) plus one slot for BOS; this
        # should normally suffice, and the retry below covers the rest.
        capacity = len(data) + 1
        buf = (llama_cpp.llama_token * capacity)()
        n = llama_cpp.llama_tokenize(self.ctx, data, buf, capacity, bos)
        if n < 0:
            # llama_tokenize reports -(required count) when the buffer is too
            # small; retry with the exact size rather than truncating.
            capacity = -n
            buf = (llama_cpp.llama_token * capacity)()
            n = llama_cpp.llama_tokenize(self.ctx, data, buf, capacity, bos)
            if n < 0:
                raise RuntimeError(f"llama_tokenize failed for {prompt!r}")
        return buf[:n]

    # if an antiprompt is present
    def use_antiprompt(self):
        """Return True if any antiprompt is configured."""
        return len(self.first_antiprompt) > 0

    def _repeat_window(self):
        """Return the last min(repeat_last_n, n_past) tokens for the repetition penalty.

        Returns an empty list when nothing has been evaluated yet. (``lst[-0:]``
        would otherwise return the whole zero-filled history.)
        """
        n = min(self.repeat_last_n, self.n_past)
        if n <= 0:
            return []
        return self.last_n_tokens[-n:]

    def _hit_antiprompt(self):
        """Return True if the token history currently ends with any antiprompt."""
        return any(ap == self.last_n_tokens[-len(ap):] for ap in self.first_antiprompt)

    # generate tokens
    def generate(self):
        """Run the eval/sample loop, yielding token ids that should be shown.

        Queued input tokens are yielded only when ``input_echo`` is set; sampled
        tokens are always yielded. The loop stops on an antiprompt, on EOS, when
        the ``n_predict`` budget runs out (never when ``n_predict == -1``), or
        after the initial input when ``init_break`` is set.

        Raises:
            Exception: if ``llama_eval`` reports failure.
        """
        while self.n_predict == -1 or self.remaining_tokens > 0 or self.use_antiprompt():
            # predict
            if len(self.embd) > 0:
                # infinite text generation via context swapping
                # if we run out of context:
                # - take the n_keep first tokens from the original prompt (via n_past)
                # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
                if self.n_past + len(self.embd) > self.n_ctx:
                    n_left = self.n_past - self.n_keep
                    self.n_past = self.n_keep
                    # insert n_left/2 tokens at the start of embd from last_n_tokens
                    _insert = self.last_n_tokens[
                        self.n_ctx - int(n_left / 2) - len(self.embd):-len(self.embd)
                    ]
                    self.embd = _insert + self.embd

                if llama_cpp.llama_eval(
                    self.ctx,
                    (llama_cpp.llama_token * len(self.embd))(*self.embd),
                    len(self.embd),
                    self.n_past,
                    self.n_threads,
                ) != 0:
                    raise Exception("Failed to llama_eval!")

            self.n_past += len(self.embd)
            self.embd = []

            if len(self.embd_inp) <= self.input_consumed:
                # out of user input, sample next token
                window = self._repeat_window()
                token = llama_cpp.llama_sample_top_p_top_k(
                    self.ctx,
                    (llama_cpp.llama_token * len(window))(*window),
                    len(window),
                    self.top_k,
                    self.top_p,
                    self.temp,
                    self.repeat_penalty,
                )
                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(token)

                # replace end of text token with newline token when in interactive mode
                if token == llama_cpp.llama_token_eos() and self.use_antiprompt() and not self.instruct:
                    token = self.llama_token_newline[0]
                    # tokenize and inject first reverse prompt
                    self.embd_inp += self.first_antiprompt[0]

                # add it to the context
                self.embd.append(token)
                # echo this to console
                self.output_echo = True
                # decrement remaining sampling budget
                self.remaining_tokens -= 1
            else:
                # output to console if input echo is on
                self.output_echo = self.input_echo
                # some user input remains from prompt or interaction, forward it to processing
                while len(self.embd_inp) > self.input_consumed:
                    token = self.embd_inp[self.input_consumed]
                    self.embd.append(token)
                    self.last_n_tokens.pop(0)
                    self.last_n_tokens.append(token)
                    self.input_consumed += 1
                    if len(self.embd) >= self.n_batch:
                        break

            # display tokens
            if self.output_echo:
                yield from self.embd

            if len(self.embd_inp) <= self.input_consumed:
                # if antiprompt is present, stop
                if self.use_antiprompt() and self._hit_antiprompt():
                    break
                # if we are using instruction mode, and we have processed the initial prompt
                if self.init_break:
                    break

            # if end of generation
            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
                break

            # respect n_predict even if antiprompt is present
            if self.use_antiprompt() and self.remaining_tokens <= 0 and self.n_predict != -1:
                if not self.instruct:
                    self.embd_inp += self.first_antiprompt[0]
                break

        self.init_break = False

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        """Free the llama context when leaving the ``with`` block."""
        llama_cpp.llama_free(self.ctx)

    def _detokenize(self, tokens):
        """Yield decoded text per token id, joining UTF-8 sequences split across tokens.

        Byte-level tokens may carry only part of a multi-byte character, so a
        per-token ``bytes.decode`` can raise; the incremental decoder buffers
        partial sequences and flushes any tail at the end.
        """
        import codecs

        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        for token in tokens:
            yield decoder.decode(llama_cpp.llama_token_to_str(self.ctx, token))
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail

    # return past text
    def past(self):
        """Yield the text of the last ``n_past`` tokens in the history (nothing if none)."""
        if self.n_past <= 0:
            return
        yield from self._detokenize(self.last_n_tokens[-self.n_past:])

    # write input
    def input(self, prompt: str):
        """Queue user text for the next ``output`` call.

        In instruct mode the text is preceded by the instruction prefix (unless
        the history already ends with it) and followed by the response suffix.
        """
        if self.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix:
            self.embd_inp += self.inp_prefix
        # No BOS: this text is appended mid-conversation (main.cpp likewise
        # tokenizes interactive input without BOS).
        self.embd_inp += self._tokenize(prompt, False)
        if self.instruct:
            self.embd_inp += self.inp_suffix

    # write output
    def output(self):
        """Reset the sampling budget, run ``generate`` and yield the decoded text."""
        self.remaining_tokens = self.n_predict
        yield from self._detokenize(self.generate())
if __name__ == "__main__":
    # Interactive chat demo: prime the model with a short transcript, then
    # alternate between reading a user turn from stdin and streaming the reply.
    from datetime import datetime

    USER_NAME = "User"
    AI_NAME = "ChatLLaMa"

    time_now = datetime.now()
    prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.
{USER_NAME}: Hello, {AI_NAME}!
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
{USER_NAME}: What time is it?
{AI_NAME}: It is {time_now.strftime("%H:%M")}.
{USER_NAME}: What year is it?
{AI_NAME}: We are in {time_now.strftime("%Y")}.
{USER_NAME}: What is a cat?
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: Name a color.
{AI_NAME}: Blue
{USER_NAME}:"""
    print("Loading model...")
    with LLaMAInteract(
        prompt,
        model="./models/30B/ggml-model-q4_0.bin",
        n_ctx=2048,
        antiprompt=[f"\n{USER_NAME}:"],
        repeat_last_n=256,
        n_predict=2048,
        temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647,
    ) as m:
        print("Loaded model!")

        # Echo the primer, then stop echoing user input for the rest of the chat.
        for i in m.output():
            print(i, end="", flush=True)
        m.input_echo = False

        def inp():
            """Read one user turn; a trailing backslash continues it onto the next line."""
            out = ""
            while (t := input()).endswith("\\"):
                out += t[:-1] + "\n"
            return out + t + "\n"

        while True:
            if m.instruct:
                print('\n> ', end="")
            else:
                print(" ", end="")
            try:
                user_text = inp()
            except EOFError:
                # Ctrl-D / closed stdin: end the session cleanly (the with-block
                # still frees the model) instead of dying with a traceback.
                print()
                break
            if m.instruct:
                m.input(user_text)
            else:
                m.input(f" {user_text}{AI_NAME}:")
                print(f"{AI_NAME}: ", end="")

            try:
                for i in m.output():
                    print(i, end="", flush=True)
            except KeyboardInterrupt:
                # Ctrl-C interrupts the reply and hands the turn back to the user.
                if not m.instruct:
                    print(f"\n{USER_NAME}:", end="")
                    m.input(f"\n{USER_NAME}:")