diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 30e0de3..615662e 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -839,9 +839,10 @@ class Llama: An embedding object. """ assert self.ctx is not None + assert self.model is not None model_name: str = model if model is not None else self.model_path - if self.model_params.embedding == False: + if self.context_params.embedding == False: raise RuntimeError( "Llama model must be created with embedding=True to call this method" ) @@ -863,7 +864,7 @@ class Llama: n_tokens = len(tokens) total_tokens += n_tokens embedding = llama_cpp.llama_get_embeddings(self.ctx)[ - : llama_cpp.llama_n_embd(self.ctx) + : llama_cpp.llama_n_embd(self.model) ] data.append(