fix: Set default pooling_type to mean, check for null pointer.

Andrei Betlen 2024-03-14 09:17:41 -04:00
parent dd0ee56217
commit d318cc8b83
2 changed files with 8 additions and 3 deletions


@@ -79,6 +79,7 @@ class Llama:
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_MEAN,
         rope_freq_base: float = 0.0,
         rope_freq_scale: float = 0.0,
         yarn_ext_factor: float = -1.0,
@@ -151,6 +152,7 @@ class Llama:
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
+            pooling_type: Pooling type, from `enum llama_pooling_type`.
             rope_freq_base: RoPE base frequency, 0 = from model
             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -271,6 +273,7 @@ class Llama:
             if rope_scaling_type is not None
             else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
         )
+        self.context_params.pooling_type = pooling_type
         self.context_params.rope_freq_base = (
             rope_freq_base if rope_freq_base != 0.0 else 0
         )
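With the hunks above, pooling_type becomes a constructor argument that is passed straight through to the underlying llama_context_params and defaults to mean pooling. A minimal usage sketch, not part of the commit: the model path is a placeholder, and LLAMA_POOLING_TYPE_CLS is assumed to be exported by llama_cpp alongside the mean/none constants.

    import llama_cpp
    from llama_cpp import Llama

    # Hypothetical GGUF path; any embedding-capable model works here.
    llm = Llama(
        model_path="./models/ggml-model-q4_0.gguf",
        embedding=True,                                 # enable embedding extraction
        pooling_type=llama_cpp.LLAMA_POOLING_TYPE_CLS,  # override the new mean-pooling default
    )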
@@ -814,9 +817,12 @@ class Llama:
             # store embeddings
             for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
+                ptr = llama_cpp.llama_get_embeddings_seq(
                     self._ctx.ctx, i
-                )[:n_embd]
+                )
+                if not ptr:
+                    raise RuntimeError("Failed to get embeddings from sequence: pooling type is not set")
+                embedding: List[float] = ptr[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]

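llama_get_embeddings_seq returns a null pointer when the context has no per-sequence pooled embedding for the requested sequence id (for example when the pooling type is LLAMA_POOLING_TYPE_NONE), so slicing the raw result would previously fail; the added check turns that case into an explicit RuntimeError. A sketch of the high-level call path that exercises this code, with a placeholder model path and arbitrary input texts:

    from llama_cpp import Llama

    llm = Llama(model_path="./models/ggml-model-q4_0.gguf", embedding=True)

    # embed() decodes the inputs in a batch and reads one pooled vector per sequence;
    # with the new mean-pooling default this returns one list of floats per input.
    vectors = llm.embed(["first sentence", "second sentence"])
    print(len(vectors), len(vectors[0]))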

@@ -579,7 +579,6 @@ class llama_model_params(ctypes.Structure):
# bool embeddings; // if true, extract embeddings (together with logits)
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
# // Abort callback
# // if it returns true, execution of llama_decode() will be aborted
# // currently works only with CPU execution
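The comment block above mirrors the llama_context_params fields that the high-level wrapper configures. The same settings can also be made through the low-level ctypes bindings; a rough sketch, assuming a local ./model.gguf and the no-argument llama_backend_init of current llama.cpp builds:

    import llama_cpp

    llama_cpp.llama_backend_init()

    model_params = llama_cpp.llama_model_default_params()
    model = llama_cpp.llama_load_model_from_file(b"./model.gguf", model_params)

    ctx_params = llama_cpp.llama_context_default_params()
    ctx_params.embeddings = True                                 # extract embeddings (together with logits)
    ctx_params.pooling_type = llama_cpp.LLAMA_POOLING_TYPE_MEAN  # pool one embedding per sequence id
    ctx = llama_cpp.llama_new_context_with_model(model, ctx_params)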