Add field to disable reseting between generations

This commit is contained in:
Andrei Betlen 2023-04-13 00:28:00 -04:00
parent 22fa5a621f
commit 6595ad84bf

View file

@ -218,6 +218,7 @@ class Llama:
top_p: float,
temp: float,
repeat_penalty: float,
reset: bool = True,
) -> Generator[
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
]:
@ -235,12 +236,14 @@ class Llama:
top_p: The top-p sampling parameter.
temp: The temperature parameter.
repeat_penalty: The repeat penalty parameter.
reset: Whether to reset the model state.
Yields:
The generated tokens.
"""
assert self.ctx is not None
self.reset()
if reset:
self.reset()
while True:
self.eval(tokens)
token = self.sample(