Fix logits_all bug

This commit is contained in:
Andrei Betlen 2023-09-30 16:02:35 -04:00
parent 6ee413d79e
commit d696251fbe

View file

@ -439,7 +439,7 @@ class Llama:
def eval_logits(self) -> Deque[List[float]]: def eval_logits(self) -> Deque[List[float]]:
return deque( return deque(
self.scores[: self.n_tokens, :].tolist(), self.scores[: self.n_tokens, :].tolist(),
maxlen=self._n_ctx if self.model_params.logits_all else 1, maxlen=self._n_ctx if self.context_params.logits_all else 1,
) )
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
@ -964,7 +964,7 @@ class Llama:
else: else:
stop_sequences = [] stop_sequences = []
if logprobs is not None and self.model_params.logits_all is False: if logprobs is not None and self.context_params.logits_all is False:
raise ValueError( raise ValueError(
"logprobs is not supported for models created with logits_all=False" "logprobs is not supported for models created with logits_all=False"
) )