From f1615f05e6032bf79d4dc0c683a518ed39aca55a Mon Sep 17 00:00:00 2001
From: Mug <>
Date: Mon, 3 Apr 2023 22:54:46 +0200
Subject: [PATCH] Chat llama.cpp example implementation

---
 examples/low_level_api_chatllama_cpp.py | 235 ++++++++++++++++++++++++
 1 file changed, 235 insertions(+)
 create mode 100644 examples/low_level_api_chatllama_cpp.py

diff --git a/examples/low_level_api_chatllama_cpp.py b/examples/low_level_api_chatllama_cpp.py
new file mode 100644
index 0000000..a244867
--- /dev/null
+++ b/examples/low_level_api_chatllama_cpp.py
@@ -0,0 +1,235 @@
"""
This is an example implementation of main.cpp from llama.cpp
Quirks:
 * It's not exactly alike, since this port is designed around programmatic I/O
 * Input is always echoed if on, so it should be turned off when using "input()"
 * The first antiprompt should be the user prompt, like "\nUser:",
   because it's added when n_predict is reached (aka generation ended prematurely)
 * n_predict can be set to -1 for unlimited length responses
"""
import llama_cpp

def toIntArray(lst):
    return [int(i) for i in lst]

# A LLaMA interactive session
class LLaMAInteract:
    def __init__(self,
        primer: str="",
        model: str="./models/30B/ggml-model-q4_0.bin",
        n_ctx: int=1024,
        seed: int=0,
        n_threads: int=8,
        antiprompt: list[str]=[],
        input_echo: bool=True,
        n_predict: int=20,
        n_batch: int=8,
        repeat_last_n: int=64,
        top_k: int=50,
        top_p: float=1.,
        temp: float=1.0,
        repeat_penalty: float=1,
    ) -> None:
        # input args
        self.n_threads = n_threads
        self.input_echo = input_echo
        self.n_predict = n_predict
        self.n_batch = n_batch
        self.repeat_last_n = repeat_last_n
        self.top_k = top_k
        self.top_p = top_p
        self.temp = temp
        self.repeat_penalty = repeat_penalty
        self.n_ctx = n_ctx
        self.seed = seed

        # runtime args
        self.input_consumed = 0
        self.embd = []
        self.embd_inp = []
        self.n_past = 0
        self.first_antiprompt = []
        self.remaining_tokens = self.n_predict
        self.output_echo = input_echo

        # model load
        self.lparams = llama_cpp.llama_context_default_params()
        self.lparams.n_ctx = self.n_ctx
        self.lparams.seed = self.seed
        self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)

        # determine the required inference memory per token:
        tmp = [0, 1, 2, 3]
        llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)

        # determine newline token
        self.llama_token_newline = (llama_cpp.llama_token * 1)()
        llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False)
        self.llama_token_newline = toIntArray(self.llama_token_newline)

        # primer feed
        if (len(primer) > 0):
            self.input(primer)
        self.n_keep = len(self.embd_inp)

        # create internal context
        self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
        self.last_n_tokens = [0] * self.n_ctx  # TODO: deque doesn't support slices

        # determine antiprompt tokens
        for i in antiprompt:
            d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))()
            n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
            self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt]))

    # whether an antiprompt is present
    def use_antiprompt(self):
        return len(self.first_antiprompt) > 0
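    # The context swap at the top of generate() can be illustrated with
    # concrete (hypothetical) numbers.  With n_ctx=1024 and n_keep=48, once
    # n_past has grown to the full 1024-token window:
    #
    #   n_left = n_past - n_keep              # 1024 - 48 = 976
    #   n_past = n_keep                       # rewind to the kept primer prefix
    #   reuse  = last_n_tokens[-(976 // 2):]  # roughly the most recent half
    #
    # i.e. the primer tokens stay, about half of the recent history is re-fed,
    # and generation continues.  This is only a sketch of the idea; the actual
    # slice below also accounts for tokens already pending in self.embd.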
    def generate(self):
        while self.remaining_tokens > 0 or self.use_antiprompt():
            # predict
            if len(self.embd) > 0:
                # infinite text generation via context swapping
                # if we run out of context:
                # - take the n_keep first tokens from the original prompt (via n_past)
                # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
                if (self.n_past + len(self.embd) > self.n_ctx):
                    n_left = self.n_past - self.n_keep
                    self.n_past = self.n_keep

                    # insert n_left/2 tokens at the start of embd from last_n_tokens
                    _insert = self.last_n_tokens[
                        -(int(n_left/2) - len(self.embd)):-len(self.embd)
                    ]
                    self.embd[:len(_insert)] = _insert
                    # TODO: still untested

                if (llama_cpp.llama_eval(
                    self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
                ) != 0):
                    raise Exception("Failed to llama_eval!")

            self.n_past += len(self.embd)
            self.embd = []
            if len(self.embd_inp) <= self.input_consumed:
                # out of user input, sample next token
                _arr = self.last_n_tokens[-min(self.repeat_last_n, self.n_past):]
                id = llama_cpp.llama_sample_top_p_top_k(
                    self.ctx,
                    (llama_cpp.llama_token * len(_arr))(*_arr),
                    len(_arr),
                    self.top_k,
                    self.top_p,
                    self.temp,
                    self.repeat_penalty,
                )
                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(int(id))

                # replace end-of-text token with newline token when in interactive mode
                if (id == llama_cpp.llama_token_eos() and self.use_antiprompt()):
                    id = self.llama_token_newline[0]
                    # tokenize and inject first reverse prompt
                    self.embd_inp += self.first_antiprompt[0]

                # add it to the context
                self.embd.append(int(id))

                # echo this to console
                self.output_echo = True

                # decrement remaining sampling budget
                self.remaining_tokens -= 1
            else:
                # output to console if input echo is on
                self.output_echo = self.input_echo

                # some user input remains from prompt or interaction, forward it to processing
                while len(self.embd_inp) > self.input_consumed:
                    self.embd.append(int(self.embd_inp[self.input_consumed]))
                    self.last_n_tokens.pop(0)
                    self.last_n_tokens.append(int(self.embd_inp[self.input_consumed]))
                    self.input_consumed += 1
                    if len(self.embd) >= self.n_batch:
                        break

            # display tokens
            if self.output_echo:
                for id in self.embd:
                    yield id

            # if an antiprompt is matched at the end of the recent tokens, stop
            if (self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed):
                for i in self.first_antiprompt:
                    if i == self.last_n_tokens[-len(i):]:
                        return

            # if end of generation
            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
                break

            # respect n_predict even if antiprompt is present
            if (self.use_antiprompt() and self.remaining_tokens <= 0 and self.n_predict != -1):
                self.embd_inp += self.first_antiprompt[0]
                break

    def past(self):
        # decode the last n_past tokens back to text
        for id in self.last_n_tokens[-self.n_past:]:
            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")

    def input(self, prompt: str):
        # tokenize user input (with BOS) and queue it for processing
        embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
        n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True)
        self.embd_inp += toIntArray(embd_arr[:n_of_tok])

    def output(self):
        # reset the sampling budget and stream decoded tokens
        self.remaining_tokens = self.n_predict
        for id in self.generate():
            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
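# The class above is designed for programmatic I/O (see the module docstring),
# so it can also be driven without input().  The helper below is a minimal,
# hypothetical sketch of that usage; it is not called anywhere in this example
# and only reuses names defined in this file.
def scripted_chat(session: LLaMAInteract, turns: list[str]) -> str:
    """Feed a fixed list of user turns and return the accumulated text."""
    transcript = ""
    # stream the model's response to the primer passed to the constructor
    for piece in session.output():
        transcript += piece
    session.input_echo = False
    for turn in turns:
        # mirror the " {text}\n" framing used by the interactive loop below
        session.input(f" {turn}\n")
        for piece in session.output():
            transcript += piece
    return transcript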
if __name__ == "__main__":
    from datetime import datetime

    USER_NAME = "User"
    AI_NAME = "ChatLLaMa"

    time_now = datetime.now()
    prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.

{USER_NAME}: Hello, {AI_NAME}!
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
{USER_NAME}: What time is it?
{AI_NAME}: It is {time_now.strftime("%H:%M")}.
{USER_NAME}: What year is it?
{AI_NAME}: We are in {time_now.strftime("%Y")}.
{USER_NAME}: What is a cat?
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: Name a color.
{AI_NAME}: Blue
{USER_NAME}:"""

    print("Loading model...")
    ll = LLaMAInteract(prompt,
        model="./models/30B/ggml-model-q4_0.bin",
        n_ctx=2048,
        antiprompt=[f"\n{USER_NAME}:"],
        repeat_last_n=256,
        n_predict=2048,
        temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647
    )
    print("Loaded model!")

    for i in ll.output():
        print(i, end="", flush=True)
    ll.input_echo = False

    inp = lambda x: f" {x}\n"
    while True:
        ll.input(inp(input(' ')))
        for i in ll.output():
            print(i, end="", flush=True)
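# Assuming a ggml-format LLaMA model is available at the path passed to
# LLaMAInteract above, the example would be started from the repository root
# with something like:
#
#   python3 examples/low_level_api_chatllama_cpp.py
#
# and each user turn is typed after the stream pauses at the reverse prompt.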