docs: Improve low-level docstrings

This commit is contained in:
Andrei Betlen 2023-11-27 19:03:02 -05:00
parent 9c68b1804a
commit 396dbf0b2b

View file

@ -212,6 +212,12 @@ LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN
# float p; // probability of the token
# } llama_token_data;
class llama_token_data(Structure):
"""Used to store token data
Attributes:
id (llama_token): token id
logit (float): log-odds of the token
p (float): probability of the token"""
_fields_ = [
("id", llama_token),
("logit", c_float),
@ -228,6 +234,12 @@ llama_token_data_p = POINTER(llama_token_data)
# bool sorted;
# } llama_token_data_array;
class llama_token_data_array(Structure):
"""Used to sample tokens given logits
Attributes:
data (ctypes.Array[llama_token_data]): token data
size (int): size of the array
sorted (bool): whether the array is sorted"""
_fields_ = [
("data", llama_token_data_p),
("size", c_size_t),
@ -282,8 +294,7 @@ class llama_batch(Structure):
token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
"""
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs"""
_fields_ = [
("n_tokens", c_int32),
@ -316,6 +327,17 @@ class llama_batch(Structure):
# bool use_mlock; // force system to keep model in RAM
# };
class llama_model_params(Structure):
"""Parameters for llama_model
Attributes:
n_gpu_layers (int): number of layers to store in VRAM
main_gpu (int): the GPU that is used for scratch and small tensors
tensor_split (ctypes.Array[ctypes.c_float]): how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
progress_callback (llama_progress_callback): called with a progress value between 0 and 1, pass NULL to disable
progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback
vocab_only (bool): only load the vocabulary, no weights
use_mmap (bool): use mmap if possible
use_mlock (bool): force system to keep model in RAM"""
_fields_ = [
("n_gpu_layers", c_int32),
("main_gpu", c_int32),
@ -353,6 +375,26 @@ class llama_model_params(Structure):
# bool embedding; // embedding mode only
# };
class llama_context_params(Structure):
"""Parameters for llama_context
Attributes:
seed (int): RNG seed, -1 for random
n_ctx (int): text context, 0 = from model
n_batch (int): prompt processing maximum batch size
n_threads (int): number of threads to use for generation
n_threads_batch (int): number of threads to use for batch processing
rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
rope_freq_base (float): RoPE base frequency, 0 = from model
rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
yarn_attn_factor (float): YaRN magnitude scaling factor
yarn_beta_fast (float): YaRN low correction dim
yarn_beta_slow (float): YaRN high correction dim
yarn_orig_ctx (int): YaRN original context size
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
f16_kv (bool): use fp16 for KV cache, fp32 otherwise
logits_all (bool): the llama_eval() call computes all logits, not just the last one
embedding (bool): embedding mode only"""
_fields_ = [
("seed", c_uint32),
("n_ctx", c_uint32),
@ -398,6 +440,15 @@ It might not exist for progress report where '.' is output repeatedly."""
# bool pure; // disable k-quant mixtures and quantize all tensors to the same type
# } llama_model_quantize_params;
class llama_model_quantize_params(Structure):
"""Parameters for llama_model_quantize
Attributes:
nthread (int): number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
ftype (int): quantize to this llama_ftype
allow_requantize (bool): allow quantizing non-f32/f16 tensors
quantize_output_tensor (bool): quantize output.weight
only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
pure (bool): disable k-quant mixtures and quantize all tensors to the same type"""
_fields_ = [
("nthread", c_int),
("ftype", c_int),
@ -489,6 +540,7 @@ class llama_timings(Structure):
# // Helpers for getting default parameters
# LLAMA_API struct llama_model_params llama_model_default_params(void);
def llama_model_default_params() -> llama_model_params:
"""Get default parameters for llama_model"""
return _lib.llama_model_default_params()
@ -498,6 +550,7 @@ _lib.llama_model_default_params.restype = llama_model_params
# LLAMA_API struct llama_context_params llama_context_default_params(void);
def llama_context_default_params() -> llama_context_params:
"""Get default parameters for llama_context"""
return _lib.llama_context_default_params()
@ -507,6 +560,7 @@ _lib.llama_context_default_params.restype = llama_context_params
# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
def llama_model_quantize_default_params() -> llama_model_quantize_params:
"""Get default parameters for llama_model_quantize"""
return _lib.llama_model_quantize_default_params()
@ -1668,6 +1722,7 @@ def llama_grammar_init(
n_rules: Union[c_size_t, int],
start_rule_index: Union[c_size_t, int],
) -> llama_grammar_p:
"""Initialize a grammar from a set of rules."""
return _lib.llama_grammar_init(rules, n_rules, start_rule_index)
@ -1681,6 +1736,7 @@ _lib.llama_grammar_init.restype = llama_grammar_p
# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
def llama_grammar_free(grammar: llama_grammar_p):
"""Free a grammar."""
return _lib.llama_grammar_free(grammar)
@ -1690,6 +1746,7 @@ _lib.llama_grammar_free.restype = None
# LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
def llama_grammar_copy(grammar: llama_grammar_p) -> llama_grammar_p:
"""Copy a grammar."""
return _lib.llama_grammar_copy(grammar)
@ -1939,6 +1996,11 @@ def llama_sample_temp(
candidates, # type: _Pointer[llama_token_data_array]
temp: Union[c_float, float],
):
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
Parameters:
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
temp: The temperature value to use for the sampling. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text."""
return _lib.llama_sample_temp(ctx, candidates, temp)
@ -1960,6 +2022,7 @@ def llama_sample_temperature(
candidates, # type: _Pointer[llama_token_data_array]
temp: Union[c_float, float],
):
"""use llama_sample_temp instead"""
return _lib.llama_sample_temperature(ctx, candidates, temp)
@ -1981,6 +2044,11 @@ def llama_sample_grammar(
candidates, # type: _Pointer[llama_token_data_array]
grammar, # type: llama_grammar_p
):
"""Apply constraints from grammar
Parameters:
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
grammar: A grammar object containing the rules and constraints to apply to the generated text."""
return _lib.llama_sample_grammar(ctx, candidates, grammar)
@ -2013,6 +2081,14 @@ def llama_sample_token_mirostat(
m: Union[c_int, int],
mu, # type: _Pointer[c_float]
) -> int:
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
Parameters:
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
m: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal."""
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
@ -2045,6 +2121,13 @@ def llama_sample_token_mirostat_v2(
eta: Union[c_float, float],
mu, # type: _Pointer[c_float]
) -> int:
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
Parameters:
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal."""
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
@ -2067,6 +2150,7 @@ def llama_sample_token_greedy(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
) -> int:
"""Selects the token with the highest probability."""
return _lib.llama_sample_token_greedy(ctx, candidates)
@ -2085,6 +2169,7 @@ def llama_sample_token(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
) -> int:
"""Randomly selects a token from the candidates based on their probabilities."""
return _lib.llama_sample_token(ctx, candidates)
@ -2105,6 +2190,7 @@ def llama_grammar_accept_token(
grammar: llama_grammar_p,
token: Union[llama_token, int],
) -> None:
"""Accepts the sampled token into the grammar"""
_lib.llama_grammar_accept_token(ctx, grammar, token)
@ -2207,6 +2293,7 @@ _lib.llama_beam_search.restype = None
# LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
def llama_get_timings(ctx: llama_context_p) -> llama_timings:
"""Get performance information"""
return _lib.llama_get_timings(ctx)
@ -2216,6 +2303,7 @@ _lib.llama_get_timings.restype = llama_timings
# LLAMA_API void llama_print_timings(struct llama_context * ctx);
def llama_print_timings(ctx: llama_context_p):
"""Print performance information"""
_lib.llama_print_timings(ctx)
@ -2225,6 +2313,7 @@ _lib.llama_print_timings.restype = None
# LLAMA_API void llama_reset_timings(struct llama_context * ctx);
def llama_reset_timings(ctx: llama_context_p):
"""Reset performance information"""
_lib.llama_reset_timings(ctx)
@ -2235,6 +2324,7 @@ _lib.llama_reset_timings.restype = None
# Print system information
# LLAMA_API const char * llama_print_system_info(void);
def llama_print_system_info() -> bytes:
"""Print system information"""
return _lib.llama_print_system_info()
@ -2249,6 +2339,9 @@ _lib.llama_print_system_info.restype = c_char_p
def llama_log_set(
log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore
):
"""Set callback for all future logging events.
If this is not called, or NULL is supplied, everything is output on stderr."""
return _lib.llama_log_set(log_callback, user_data)