From 0b32bb3d43638b8cd606df0c83f89fdcede7ed1c Mon Sep 17 00:00:00 2001
From: Mug <>
Date: Tue, 4 Apr 2023 11:48:48 +0200
Subject: [PATCH] Add instruction mode

---
 examples/low_level_api_chatllama_cpp.py | 101 +++++++++++++++---------
 1 file changed, 64 insertions(+), 37 deletions(-)

diff --git a/examples/low_level_api_chatllama_cpp.py b/examples/low_level_api_chatllama_cpp.py
index a244867..6462121 100644
--- a/examples/low_level_api_chatllama_cpp.py
+++ b/examples/low_level_api_chatllama_cpp.py
@@ -5,24 +5,26 @@ Quirks:
  * Input is always echoed if on, so it should be turned off when using "input()"
  * The first antiprompt should be the userprompt like "\nUser:",
    because its added when n_predict is reached (aka generation ended prematurely)
- * n_predict can be set to -1 for unlimited length responses
+ * n_predict can be set to -1 for unlimited length responses (or just a really high value)
+ * It's always in interactive mode, generation ends either by reaching an antiprompt
+   or running out of n_predict.
+ * Instruction mode adds its own antiprompt
 """
 import llama_cpp
 
-def toIntArray(lst):
-    return [int(i) for i in lst]
-
 # A LLaMA interactive session
 class LLaMAInteract:
     def __init__(self,
         primer: str="",
         model: str="./models/30B/ggml-model-q4_0.bin",
+        instruct: bool=False,
         n_ctx: int=1024,
         seed: int=0,
         n_threads: int=8,
         antiprompt: list[str]=[],
         input_echo: bool=True,
         n_predict: int=20,
+        n_keep: int=0,
         n_batch: int=8,
         repeat_last_n: int=64,
         top_k: int=50,
@@ -31,17 +33,17 @@ class LLaMAInteract:
         repeat_penalty: float=1,
     ) -> None:
         # input args
+        self.instruct = instruct
         self.n_threads = n_threads
         self.input_echo = input_echo
         self.n_predict = n_predict
+        self.n_keep = n_keep
         self.n_batch = n_batch
         self.repeat_last_n = repeat_last_n
         self.top_k=top_k
         self.top_p=top_p
         self.temp=temp
         self.repeat_penalty=repeat_penalty
-        self.n_ctx = n_ctx
-        self.seed = seed
 
         # runtime args
         self.input_consumed = 0
@@ -54,8 +56,8 @@ class LLaMAInteract:
 
         # model load
         self.lparams = llama_cpp.llama_context_default_params()
-        self.lparams.n_ctx = self.n_ctx
-        self.lparams.seed = self.seed
+        self.lparams.n_ctx = n_ctx
+        self.lparams.seed = seed
         self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)
 
         # determine the required inference memory per token:
@@ -63,29 +65,44 @@ class LLaMAInteract:
         llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)
 
         # determine newline token
-        self.llama_token_newline = (llama_cpp.llama_token * 1)()
-        llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False)
-        self.llama_token_newline = toIntArray(self.llama_token_newline)
+        self.llama_token_newline = self._tokenize("\n", False)
+        self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
+        self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)
+
+        # add instruction as antiprompt
+        if (self.instruct):
+            self.first_antiprompt.append(self.inp_prefix)
 
         # primer feed
         if (len(primer) > 0):
-            self.input(primer)
-            self.n_keep = len(self.embd_inp)
+            self.embd_inp += self._tokenize(primer)
+
+        # break immediately if using instruct
+        self.init_break = self.instruct
+
+        # number of tokens to keep when resetting context
+        if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
+            self.n_keep = len(self.embd_inp)
 
         # create internal context
-        self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
         self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices
 
         # determine antiprompt tokens
         for i in antiprompt:
-            d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))()
-            n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
-            self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt]))
+            self.first_antiprompt.append(self._tokenize(i, False))
+
+    # tokenize a prompt
+    def _tokenize(self, prompt, bos=True):
+        _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
+        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
+        return _arr[:_n]
 
     # if an antiprompt is present
     def use_antiprompt(self):
         return len(self.first_antiprompt) > 0
 
+    # generate tokens
     def generate(self):
         while self.remaining_tokens > 0 or self.use_antiprompt():
             # predict
@@ -125,16 +142,16 @@ class LLaMAInteract:
                     self.repeat_penalty,
                 )
                 self.last_n_tokens.pop(0)
-                self.last_n_tokens.append(int(id))
+                self.last_n_tokens.append(id)
 
                 # replace end of text token with newline token when in interactive mode
-                if (id == llama_cpp.llama_token_eos() and self.use_antiprompt()):
+                if (id == llama_cpp.llama_token_eos() and self.use_antiprompt() and not self.instruct):
                     id = self.llama_token_newline[0]
                     # tokenize and inject first reverse prompt
                     self.embd_inp += self.first_antiprompt[0]
 
                 # add it to the context
-                self.embd.append(int(id))
+                self.embd.append(id)
 
                 # echo this to console
                 self.output_echo = True
@@ -147,9 +164,9 @@ class LLaMAInteract:
 
             # some user input remains from prompt or interaction, forward it to processing
             while len(self.embd_inp) > self.input_consumed:
-                self.embd.append(int(self.embd_inp[self.input_consumed]))
+                self.embd.append(self.embd_inp[self.input_consumed])
                 self.last_n_tokens.pop(0)
-                self.last_n_tokens.append(int(self.embd_inp[self.input_consumed]))
+                self.last_n_tokens.append(self.embd_inp[self.input_consumed])
                 self.input_consumed += 1
                 if len(self.embd) >= self.n_batch:
                     break
@@ -159,11 +176,17 @@ class LLaMAInteract:
             for id in self.embd:
                 yield id
 
-            # if antiprompt is present, stop
-            if (self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed):
-                for i in self.first_antiprompt:
-                    if i == self.last_n_tokens[-len(i):]:
-                        return
+            if (len(self.embd_inp) <= self.input_consumed):
+                # if antiprompt is present, stop
+                if (self.use_antiprompt()):
+                    for i in self.first_antiprompt:
+                        if i == self.last_n_tokens[-len(i):]:
+                            return
+
+                # if we are using instruction mode, and we have processed the initial prompt
+                if (self.init_break):
+                    self.init_break = False
+                    break
 
             # if end of generation
             if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
@@ -174,15 +197,20 @@ class LLaMAInteract:
                 self.embd_inp += self.first_antiprompt[0]
             break
 
+    # return past text
    def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
             yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
 
+    # write input
     def input(self, prompt: str):
-        embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
-        n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True)
-        self.embd_inp += toIntArray(embd_arr[:n_of_tok])
+        if (self.instruct):
+            self.embd_inp += self.inp_prefix
+        self.embd_inp += self._tokenize(prompt + "\n")
+        if (self.instruct):
+            self.embd_inp += self.inp_suffix
 
+    # write output
     def output(self):
         self.remaining_tokens = self.n_predict
         for id in self.generate():
@@ -193,7 +221,7 @@ if __name__ == "__main__":
 
     USER_NAME="User"
     AI_NAME="ChatLLaMa"
-    
+
     time_now = datetime.now()
     prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
 {AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
@@ -214,7 +242,7 @@ The transcript only includes text, it does not include markup like HTML and Mark
 {USER_NAME}:"""
 
     print("Loading model...")
-    ll = LLaMAInteract(prompt,
+    m = LLaMAInteract(prompt,
         model="./models/30B/ggml-model-q4_0.bin",
         n_ctx=2048,
         antiprompt=[f"\n{USER_NAME}:"],
@@ -224,12 +252,11 @@ The transcript only includes text, it does not include markup like HTML and Mark
     )
     print("Loaded model!")
 
-    for i in ll.output():
+    for i in m.output():
         print(i,end="",flush=True)
-    ll.input_echo = False
+    m.input_echo = False
 
-    inp = lambda x: f" {x}\n"
     while True:
-        ll.input(inp(input(' ')))
-        for i in ll.output():
+        m.input(" " + input('\n> ' if m.instruct else " "))
+        for i in m.output():
             print(i,end="",flush=True)
\ No newline at end of file
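
For context, below is a minimal usage sketch (not part of the patch) of how the instruction mode added above might be driven. The import path, model path, prompt text, and parameter values are assumptions for illustration; the class, its parameters, and its methods come from the patched example file.

# Hypothetical driver for the new instruction mode; paths and values are placeholders.
from low_level_api_chatllama_cpp import LLaMAInteract  # assumed import path

m = LLaMAInteract(
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
    model="./models/7B/ggml-model-q4_0.bin",  # placeholder model path
    instruct=True,   # wraps each input in the "\n\n### Instruction:\n\n" / "\n\n### Response:\n\n" markers
    n_ctx=2048,
    n_predict=256,
    n_keep=-1,       # negative values are clamped to the primer length in __init__
)

# Consume the primer; with instruct=True, init_break stops generation once it is processed.
for chunk in m.output():
    print(chunk, end="", flush=True)
m.input_echo = False

while True:
    m.input(input("\n> "))  # inp_prefix / inp_suffix are added around the text because instruct=True
    for chunk in m.output():
        print(chunk, end="", flush=True)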