diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 4344418..916fe07 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -176,7 +176,6 @@ class Llama: if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) - n_vocab = self.n_vocab() n_ctx = self.n_ctx() @@ -575,9 +574,9 @@ class Llama: else: inputs = input - data = [] + data: List[EmbeddingData] = [] total_tokens = 0 - for input in inputs: + for index, input in enumerate(inputs): tokens = self.tokenize(input.encode("utf-8")) self.reset() self.eval(tokens) @@ -587,20 +586,20 @@ class Llama: : llama_cpp.llama_n_embd(self.ctx) ] - if self.verbose: - llama_cpp.llama_print_timings(self.ctx) data.append( { "object": "embedding", "embedding": embedding, - "index": 0, + "index": index, } ) + if self.verbose: + llama_cpp.llama_print_timings(self.ctx) return { "object": "list", "data": data, - "model": self.model_path, + "model": model_name, "usage": { "prompt_tokens": total_tokens, "total_tokens": total_tokens,