Fix mirostat sampling

This commit is contained in:
Andrei Betlen 2024-01-19 08:31:59 -05:00
parent 141293a75b
commit 3babe3512c

View file

@ -329,6 +329,8 @@ class Llama:
(n_ctx, self._n_vocab), dtype=np.single
)
self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context
@property
def ctx(self) -> llama_cpp.llama_context_p:
assert self._ctx.ctx is not None
@ -516,7 +518,7 @@ class Llama:
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
mu=2.0 * mirostat_tau,
mu=ctypes.pointer(self._mirostat_mu),
m=100,
)
elif mirostat_mode == 2:
@ -525,7 +527,7 @@ class Llama:
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
mu=2.0 * mirostat_tau,
mu=ctypes.pointer(self._mirostat_mu)
)
else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
@ -581,6 +583,10 @@ class Llama:
Yields:
The generated tokens.
"""
# Reset mirostat sampling
self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
# Check for kv cache prefix match
if reset and self.n_tokens > 0:
longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]):
@ -595,12 +601,15 @@ class Llama:
tokens = tokens[longest_prefix:]
self.n_tokens = longest_prefix
# Reset the model state
if reset:
self.reset()
# Reset the grammar
if grammar is not None:
grammar.reset()
# Eval and sample
while True:
self.eval(tokens)
token = self.sample(