Add instruction mode

This commit is contained in:
Mug 2023-04-04 11:48:48 +02:00
parent f1615f05e6
commit 0b32bb3d43

View file

@ -5,24 +5,26 @@ Quirks:
* Input is always echoed if on, so it should be turned off when using "input()"
* The first antiprompt should be the userprompt like "\nUser:",
because its added when n_predict is reached (aka generation ended prematurely)
* n_predict can be set to -1 for unlimited length responses
* n_predict can be set to -1 for unlimited length responses (or just a really high value)
* It's always in interactive mode, generation ends either by reaching an antiprompt
or running out of n_predict.
* Instruction mode adds its own antiprompt
"""
import llama_cpp
def toIntArray(lst):
return [int(i) for i in lst]
# A LLaMA interactive session
class LLaMAInteract:
def __init__(self,
primer: str="",
model: str="./models/30B/ggml-model-q4_0.bin",
instruct: bool=False,
n_ctx: int=1024,
seed: int=0,
n_threads: int=8,
antiprompt: list[str]=[],
input_echo: bool=True,
n_predict: int=20,
n_keep: int=0,
n_batch: int=8,
repeat_last_n: int=64,
top_k: int=50,
@ -31,17 +33,17 @@ class LLaMAInteract:
repeat_penalty: float=1,
) -> None:
# input args
self.instruct = instruct
self.n_threads = n_threads
self.input_echo = input_echo
self.n_predict = n_predict
self.n_keep = n_keep
self.n_batch = n_batch
self.repeat_last_n = repeat_last_n
self.top_k=top_k
self.top_p=top_p
self.temp=temp
self.repeat_penalty=repeat_penalty
self.n_ctx = n_ctx
self.seed = seed
# runtime args
self.input_consumed = 0
@ -54,8 +56,8 @@ class LLaMAInteract:
# model load
self.lparams = llama_cpp.llama_context_default_params()
self.lparams.n_ctx = self.n_ctx
self.lparams.seed = self.seed
self.lparams.n_ctx = n_ctx
self.lparams.seed = seed
self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)
# determine the required inference memory per token:
@ -63,29 +65,44 @@ class LLaMAInteract:
llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)
# determine newline token
self.llama_token_newline = (llama_cpp.llama_token * 1)()
llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False)
self.llama_token_newline = toIntArray(self.llama_token_newline)
self.llama_token_newline = self._tokenize("\n", False)
self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)
# add instruction as antiprompt
if (self.instruct):
self.first_antiprompt.append(self.inp_prefix)
# primer feed
if (len(primer) > 0):
self.input(primer)
self.n_keep = len(self.embd_inp)
self.embd_inp += self._tokenize(primer)
# break immediately if using instruct
self.init_break = self.instruct
# number of tokens to keep when resetting context
if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
self.n_keep = len(self.embd_inp)
# create internal context
self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices
# determine antiprompt tokens
for i in antiprompt:
d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))()
n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt]))
self.first_antiprompt.append(self._tokenize(i, False))
# tokenize a prompt
def _tokenize(self, prompt, bos=True):
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
return _arr[:_n]
# if an antiprompt is present
def use_antiprompt(self):
return len(self.first_antiprompt) > 0
# generate tokens
def generate(self):
while self.remaining_tokens > 0 or self.use_antiprompt():
# predict
@ -125,16 +142,16 @@ class LLaMAInteract:
self.repeat_penalty,
)
self.last_n_tokens.pop(0)
self.last_n_tokens.append(int(id))
self.last_n_tokens.append(id)
# replace end of text token with newline token when in interactive mode
if (id == llama_cpp.llama_token_eos() and self.use_antiprompt()):
if (id == llama_cpp.llama_token_eos() and self.use_antiprompt() and not self.instruct):
id = self.llama_token_newline[0]
# tokenize and inject first reverse prompt
self.embd_inp += self.first_antiprompt[0]
# add it to the context
self.embd.append(int(id))
self.embd.append(id)
# echo this to console
self.output_echo = True
@ -147,9 +164,9 @@ class LLaMAInteract:
# some user input remains from prompt or interaction, forward it to processing
while len(self.embd_inp) > self.input_consumed:
self.embd.append(int(self.embd_inp[self.input_consumed]))
self.embd.append(self.embd_inp[self.input_consumed])
self.last_n_tokens.pop(0)
self.last_n_tokens.append(int(self.embd_inp[self.input_consumed]))
self.last_n_tokens.append(self.embd_inp[self.input_consumed])
self.input_consumed += 1
if len(self.embd) >= self.n_batch:
break
@ -159,11 +176,17 @@ class LLaMAInteract:
for id in self.embd:
yield id
# if antiprompt is present, stop
if (self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed):
for i in self.first_antiprompt:
if i == self.last_n_tokens[-len(i):]:
return
if (len(self.embd_inp) <= self.input_consumed):
# if antiprompt is present, stop
if (self.use_antiprompt()):
for i in self.first_antiprompt:
if i == self.last_n_tokens[-len(i):]:
return
# if we are using instruction mode, and we have processed the initial prompt
if (self.init_break):
self.init_break = False
break
# if end of generation
if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
@ -174,15 +197,20 @@ class LLaMAInteract:
self.embd_inp += self.first_antiprompt[0]
break
# return past text
def past(self):
for id in self.last_n_tokens[-self.n_past:]:
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
# write input
def input(self, prompt: str):
embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True)
self.embd_inp += toIntArray(embd_arr[:n_of_tok])
if (self.instruct):
self.embd_inp += self.inp_prefix
self.embd_inp += self._tokenize(prompt + "\n")
if (self.instruct):
self.embd_inp += self.inp_suffix
# write output
def output(self):
self.remaining_tokens = self.n_predict
for id in self.generate():
@ -193,7 +221,7 @@ if __name__ == "__main__":
USER_NAME="User"
AI_NAME="ChatLLaMa"
time_now = datetime.now()
prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}s requests immediately and with details and precision.
@ -214,7 +242,7 @@ The transcript only includes text, it does not include markup like HTML and Mark
{USER_NAME}:"""
print("Loading model...")
ll = LLaMAInteract(prompt,
m = LLaMAInteract(prompt,
model="./models/30B/ggml-model-q4_0.bin",
n_ctx=2048,
antiprompt=[f"\n{USER_NAME}:"],
@ -224,12 +252,11 @@ The transcript only includes text, it does not include markup like HTML and Mark
)
print("Loaded model!")
for i in ll.output():
for i in m.output():
print(i,end="",flush=True)
ll.input_echo = False
m.input_echo = False
inp = lambda x: f" {x}\n"
while True:
ll.input(inp(input(' ')))
for i in ll.output():
m.input(" " + input('\n> ' if m.instruct else " "))
for i in m.output():
print(i,end="",flush=True)