diff --git a/examples/low_level_api/low_level_api_chatllama_cpp.py b/examples/low_level_api/low_level_api_chatllama_cpp.py
index 594d15e..02adf3c 100644
--- a/examples/low_level_api/low_level_api_chatllama_cpp.py
+++ b/examples/low_level_api/low_level_api_chatllama_cpp.py
@@ -33,6 +33,7 @@ class LLaMAInteract:
         top_p: float=1.,
         temp: float=1.0,
         repeat_penalty: float=1,
+        init_break: bool=True,
         instruct_inp_prefix: str="\n\n### Instruction:\n\n",
         instruct_inp_suffix: str="\n\n### Response:\n\n",
     ) -> None:
@@ -48,6 +49,7 @@ class LLaMAInteract:
         self.top_p=top_p
         self.temp=temp
         self.repeat_penalty=repeat_penalty
+        self.init_break = init_break
 
         # runtime args
         self.input_consumed = 0
@@ -81,9 +83,6 @@ class LLaMAInteract:
         if (len(primer) > 0):
             self.embd_inp += self._tokenize(primer)
 
-        # break immediately if using instruct
-        self.init_break = self.instruct
-
         # number of tokens to keep when resetting context
         if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
             self.n_keep = len(self.embd_inp)
@@ -182,13 +181,14 @@ class LLaMAInteract:
             if (len(self.embd_inp) <= self.input_consumed):
                 # if antiprompt is present, stop
                 if (self.use_antiprompt()):
-                    for i in self.first_antiprompt:
-                        if i == self.last_n_tokens[-len(i):]:
-                            return
+                    if True in [
+                        i == self.last_n_tokens[-len(i):]
+                        for i in self.first_antiprompt
+                    ]:
+                        break
 
                 # if we are using instruction mode, and we have processed the initial prompt
                 if (self.init_break):
-                    self.init_break = False
                     break
 
             # if end of generation
@@ -201,6 +201,8 @@ class LLaMAInteract:
                 self.embd_inp += self.first_antiprompt[0]
                 break
 
+        self.init_break = False
+
     def __enter__(self):
         return self
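
Note on the change, with a hedged sketch below: `init_break` is promoted from a value derived in `__init__` (`self.init_break = self.instruct`) to an explicit constructor flag defaulting to `True`; an antiprompt match now `break`s out of the generation loop instead of `return`ing; and the flag is cleared once, after the loop, so the early break fires only on the first pass. The following is a minimal standalone sketch of that control flow, not the patched class itself; the name `FlagDemo` and its toy `generate` method are hypothetical stand-ins.

```python
# Minimal sketch (illustrative only, not the patched LLaMAInteract class).
# It mirrors the new control flow: an explicit init_break constructor flag,
# a break out of the loop, and the flag cleared once *after* the loop.

class FlagDemo:
    def __init__(self, init_break: bool = True) -> None:
        # now a constructor argument, as in the patch
        self.init_break = init_break

    def generate(self, tokens):
        out = []
        for tok in tokens:
            out.append(tok)
            if self.init_break:
                break  # early exit, first pass only
        # cleared after the loop, so subsequent calls run to completion
        self.init_break = False
        return out

demo = FlagDemo(init_break=True)
print(demo.generate(["a", "b", "c"]))  # ['a']            -- first pass breaks early
print(demo.generate(["a", "b", "c"]))  # ['a', 'b', 'c']  -- flag already cleared
```

Clearing the flag after the loop rather than at the break site means every exit path, including the new antiprompt `break`, consumes the one-shot early stop.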