Fix logits_all bug

Andrei Betlen 2023-09-30 16:02:35 -04:00
parent 6ee413d79e
commit d696251fbe


@@ -439,7 +439,7 @@ class Llama:
     def eval_logits(self) -> Deque[List[float]]:
         return deque(
             self.scores[: self.n_tokens, :].tolist(),
-            maxlen=self._n_ctx if self.model_params.logits_all else 1,
+            maxlen=self._n_ctx if self.context_params.logits_all else 1,
         )
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
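
The bug: logits_all lives on llama.cpp's context parameters, not the model parameters, so reading it from self.model_params consulted the wrong struct. In this hunk the flag decides how many rows of logits the deque retains; a toy illustration of that maxlen behavior, independent of the patch (the sample scores are made up):

from collections import deque

scores = [[0.1, 0.9], [0.3, 0.7], [0.5, 0.5]]  # hypothetical per-token logits
# logits_all=True keeps a row of logits for every evaluated token...
print(deque(scores, maxlen=3))  # deque([[0.1, 0.9], [0.3, 0.7], [0.5, 0.5]], maxlen=3)
# ...while logits_all=False retains only the last token's row.
print(deque(scores, maxlen=1))  # deque([[0.5, 0.5]], maxlen=1)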
@@ -964,7 +964,7 @@ class Llama:
         else:
             stop_sequences = []
 
-        if logprobs is not None and self.model_params.logits_all is False:
+        if logprobs is not None and self.context_params.logits_all is False:
             raise ValueError(
                 "logprobs is not supported for models created with logits_all=False"
             )
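
For context, a minimal sketch of the guarded behavior, assuming the llama-cpp-python high-level API (the model path and prompt are placeholders):

from llama_cpp import Llama

# logits_all is a context parameter: it must be set at construction time
# for per-token logprobs to be available during completion.
llm = Llama(model_path="./models/7B/model.gguf", logits_all=True)

# With logits_all=False this call would raise the ValueError above;
# before this fix the guard read the flag from the wrong params struct.
out = llm.create_completion("Q: What is 2+2? A:", max_tokens=8, logprobs=5)
print(out["choices"][0]["logprobs"])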