fix: Remove deprecated cfg sampling functions

This commit is contained in:
Andrei Betlen 2024-02-28 14:37:07 -05:00
parent 727d60c28a
commit 8c71725d53
2 changed files with 1 additions and 50 deletions

View file

@ -357,21 +357,6 @@ class _LlamaContext:
penalty_present,
)
def sample_classifier_free_guidance(
self,
candidates: "_LlamaTokenDataArray",
guidance_ctx: "_LlamaContext",
scale: float,
):
assert self.ctx is not None
assert guidance_ctx.ctx is not None
llama_cpp.llama_sample_classifier_free_guidance(
self.ctx,
llama_cpp.byref(candidates.candidates),
guidance_ctx.ctx,
scale,
)
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
assert self.ctx is not None
llama_cpp.llama_sample_softmax(
@ -720,7 +705,7 @@ class _LlamaSamplingContext:
return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
def sample(
self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext] = None, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
self, ctx_main: _LlamaContext, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
):
n_vocab = ctx_main.model.n_vocab()
id: int = 0
@ -741,11 +726,6 @@ class _LlamaSamplingContext:
) # TODO: Only create this once
token_data_array.copy_logits(logits_array)
if ctx_cfg is not None:
ctx_main.sample_classifier_free_guidance(
token_data_array, ctx_cfg, self.params.cfg_scale
)
# apply penalties
if len(self.prev) > 0:
nl_token = ctx_main.model.token_nl()

View file

@ -2129,35 +2129,6 @@ def llama_sample_apply_guidance(
...
# LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
# struct llama_context * ctx,
# llama_token_data_array * candidates,
# struct llama_context * guidance_ctx,
# float scale),
# "use llama_sample_apply_guidance() instead");
@ctypes_function(
"llama_sample_classifier_free_guidance",
[
llama_context_p_ctypes,
llama_token_data_array_p,
llama_context_p_ctypes,
ctypes.c_float,
],
None,
)
def llama_sample_classifier_free_guidance(
ctx: llama_context_p,
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
guidance_ctx: llama_context_p,
scale: Union[ctypes.c_float, float],
/,
):
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
...
# /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
# LLAMA_API void llama_sample_softmax(
# struct llama_context * ctx,