Fix bug in embedding

This commit is contained in:
Andrei Betlen 2023-09-30 13:20:22 -04:00
parent bca965325d
commit 42bb721d64

View file

@ -839,9 +839,10 @@ class Llama:
An embedding object. An embedding object.
""" """
assert self.ctx is not None assert self.ctx is not None
assert self.model is not None
model_name: str = model if model is not None else self.model_path model_name: str = model if model is not None else self.model_path
if self.model_params.embedding == False: if self.context_params.embedding == False:
raise RuntimeError( raise RuntimeError(
"Llama model must be created with embedding=True to call this method" "Llama model must be created with embedding=True to call this method"
) )
@ -863,7 +864,7 @@ class Llama:
n_tokens = len(tokens) n_tokens = len(tokens)
total_tokens += n_tokens total_tokens += n_tokens
embedding = llama_cpp.llama_get_embeddings(self.ctx)[ embedding = llama_cpp.llama_get_embeddings(self.ctx)[
: llama_cpp.llama_n_embd(self.ctx) : llama_cpp.llama_n_embd(self.model)
] ]
data.append( data.append(