From 1347e1d050fc5a9a32ffe0bb3e22858da28003bd Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 10 Apr 2024 02:40:41 -0400 Subject: [PATCH] feat: Add typechecking for ctypes structure attributes --- llama_cpp/llama_cpp.py | 216 ++++++++++++++++++++++++++++++++++------- 1 file changed, 180 insertions(+), 36 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 99ae7de..2450d11 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -237,7 +237,7 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61 # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' LLAMA_FILE_MAGIC_GGSN = 0x6767736E -#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' +# define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' LLAMA_FILE_MAGIC_GGSQ = 0x67677371 # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN @@ -245,9 +245,9 @@ LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN # define LLAMA_SESSION_VERSION 5 LLAMA_SESSION_VERSION = 5 -#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ +# define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ -#define LLAMA_STATE_SEQ_VERSION 1 +# define LLAMA_STATE_SEQ_VERSION 1 LLAMA_STATE_SEQ_VERSION = 1 # struct llama_model; @@ -431,6 +431,11 @@ class llama_token_data(ctypes.Structure): logit (float): log-odds of the token p (float): probability of the token""" + if TYPE_CHECKING: + id: llama_token + logit: float + p: float + _fields_ = [ ("id", llama_token), ("logit", ctypes.c_float), @@ -454,6 +459,11 @@ class llama_token_data_array(ctypes.Structure): size (int): size of the array sorted (bool): whether the array is sorted""" + if TYPE_CHECKING: + data: CtypesArray[llama_token_data] + size: int + sorted: bool + _fields_ = [ ("data", llama_token_data_p), ("size", ctypes.c_size_t), @@ -515,6 +525,15 @@ class llama_batch(ctypes.Structure): logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output """ + if TYPE_CHECKING: + n_tokens: int + token: CtypesArray[llama_token] + embd: CtypesArray[ctypes.c_float] + pos: CtypesArray[CtypesArray[llama_pos]] + n_seq_id: CtypesArray[ctypes.c_int] + seq_id: CtypesArray[CtypesArray[llama_seq_id]] + logits: CtypesArray[ctypes.c_int8] + _fields_ = [ ("n_tokens", ctypes.c_int32), ("token", ctypes.POINTER(llama_token)), @@ -609,6 +628,18 @@ class llama_model_params(ctypes.Structure): use_mmap (bool): use mmap if possible use_mlock (bool): force system to keep model in RAM""" + if TYPE_CHECKING: + n_gpu_layers: int + split_mode: int + main_gpu: int + tensor_split: CtypesArray[ctypes.c_float] + progress_callback: Callable[[float, ctypes.c_void_p], bool] + progress_callback_user_data: ctypes.c_void_p + kv_overrides: CtypesArray[llama_model_kv_override] + vocab_only: bool + use_mmap: bool + use_mlock: bool + _fields_ = [ ("n_gpu_layers", ctypes.c_int32), ("split_mode", ctypes.c_int), @@ -696,6 +727,34 @@ class llama_context_params(ctypes.Structure): abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback """ + if TYPE_CHECKING: + seed: int + n_ctx: int + n_batch: int + n_ubatch: int + n_seq_max: int + n_threads: int + n_threads_batch: int + rope_scaling_type: int + pooling_type: int + rope_freq_base: float + rope_freq_scale: float + yarn_ext_factor: float + yarn_attn_factor: float + yarn_beta_fast: float + yarn_beta_slow: float + yarn_orig_ctx: int + defrag_thold: float + cb_eval: Callable[[ctypes.c_void_p, bool], bool] + cb_eval_user_data: ctypes.c_void_p + type_k: int + type_v: int + logits_all: bool + embeddings: bool + offload_kqv: bool + abort_callback: Callable[[ctypes.c_void_p], bool] + abort_callback_data: ctypes.c_void_p + _fields_ = [ ("seed", ctypes.c_uint32), ("n_ctx", ctypes.c_uint32), @@ -771,6 +830,18 @@ class llama_model_quantize_params(ctypes.Structure): kv_overrides (ctypes.c_void_p): pointer to vector containing overrides """ + if TYPE_CHECKING: + nthread: int + ftype: int + output_tensor_type: int + token_embedding_type: int + allow_requantize: bool + quantize_output_tensor: bool + only_copy: bool + pure: bool + imatrix: ctypes.c_void_p + kv_overrides: ctypes.c_void_p + _fields_ = [ ("nthread", ctypes.c_int32), ("ftype", ctypes.c_int), @@ -828,6 +899,10 @@ LLAMA_GRETYPE_CHAR_ALT = 6 # uint32_t value; // Unicode code point or rule ID # } llama_grammar_element; class llama_grammar_element(ctypes.Structure): + if TYPE_CHECKING: + type: int + value: int + _fields_ = [ ("type", ctypes.c_int), ("value", ctypes.c_uint32), @@ -851,6 +926,17 @@ llama_grammar_element_p = ctypes.POINTER(llama_grammar_element) # int32_t n_eval; # }; class llama_timings(ctypes.Structure): + if TYPE_CHECKING: + t_start_ms: float + t_end_ms: float + t_load_ms: float + t_sample_ms: float + t_p_eval_ms: float + t_eval_ms: float + n_sample: int + n_p_eval: int + n_eval: int + _fields_ = [ ("t_start_ms", ctypes.c_double), ("t_end_ms", ctypes.c_double), @@ -951,7 +1037,8 @@ GGML_NUMA_STRATEGY_COUNT = 5 [ctypes.c_int], None, ) -def llama_numa_init(numa: int, /): ... +def llama_numa_init(numa: int, /): + ... # // Call once at the end of the program - currently only used for MPI @@ -976,7 +1063,8 @@ def llama_backend_free(): ) def llama_load_model_from_file( path_model: bytes, params: llama_model_params, / -) -> Optional[llama_model_p]: ... +) -> Optional[llama_model_p]: + ... # LLAMA_API void llama_free_model(struct llama_model * model); @@ -985,7 +1073,8 @@ def llama_load_model_from_file( [llama_model_p_ctypes], None, ) -def llama_free_model(model: llama_model_p, /): ... +def llama_free_model(model: llama_model_p, /): + ... # LLAMA_API struct llama_context * llama_new_context_with_model( @@ -998,7 +1087,8 @@ def llama_free_model(model: llama_model_p, /): ... ) def llama_new_context_with_model( model: llama_model_p, params: llama_context_params, / -) -> Optional[llama_context_p]: ... +) -> Optional[llama_context_p]: + ... # // Frees all allocated memory @@ -1019,82 +1109,98 @@ def llama_free(ctx: llama_context_p, /): [], ctypes.c_int64, ) -def llama_time_us() -> int: ... +def llama_time_us() -> int: + ... # LLAMA_API size_t llama_max_devices(void); @ctypes_function("llama_max_devices", [], ctypes.c_size_t) -def llama_max_devices() -> int: ... +def llama_max_devices() -> int: + ... # LLAMA_API bool llama_supports_mmap (void); @ctypes_function("llama_supports_mmap", [], ctypes.c_bool) -def llama_supports_mmap() -> bool: ... +def llama_supports_mmap() -> bool: + ... # LLAMA_API bool llama_supports_mlock (void); @ctypes_function("llama_supports_mlock", [], ctypes.c_bool) -def llama_supports_mlock() -> bool: ... +def llama_supports_mlock() -> bool: + ... # LLAMA_API bool llama_supports_gpu_offload(void); @ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool) -def llama_supports_gpu_offload() -> bool: ... +def llama_supports_gpu_offload() -> bool: + ... # LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); @ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) -def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ... +def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: + ... # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); @ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_ctx(ctx: llama_context_p, /) -> int: ... +def llama_n_ctx(ctx: llama_context_p, /) -> int: + ... # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); @ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_batch(ctx: llama_context_p, /) -> int: ... +def llama_n_batch(ctx: llama_context_p, /) -> int: + ... # LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); @ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_ubatch(ctx: llama_context_p, /) -> int: ... +def llama_n_ubatch(ctx: llama_context_p, /) -> int: + ... # LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); @ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_seq_max(ctx: llama_context_p, /) -> int: ... +def llama_n_seq_max(ctx: llama_context_p, /) -> int: + ... # LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); @ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_vocab_type(model: llama_model_p, /) -> int: ... +def llama_vocab_type(model: llama_model_p, /) -> int: + ... # LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); @ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_rope_type(model: llama_model_p, /) -> int: ... +def llama_rope_type(model: llama_model_p, /) -> int: + ... # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_vocab(model: llama_model_p, /) -> int: ... +def llama_n_vocab(model: llama_model_p, /) -> int: + ... # LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); @ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_ctx_train(model: llama_model_p, /) -> int: ... +def llama_n_ctx_train(model: llama_model_p, /) -> int: + ... # LLAMA_API int32_t llama_n_embd (const struct llama_model * model); @ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_embd(model: llama_model_p, /) -> int: ... +def llama_n_embd(model: llama_model_p, /) -> int: + ... # LLAMA_API int32_t llama_n_layer (const struct llama_model * model); @ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_layer(model: llama_model_p, /) -> int: ... +def llama_n_layer(model: llama_model_p, /) -> int: + ... # // Get the model's RoPE frequency scaling factor @@ -1358,6 +1464,9 @@ class llama_kv_cache_view_cell(ctypes.Structure): pos (llama_pos): The position for this cell. Takes KV cache shifts into account. May be negative if the cell is not populated.""" + if TYPE_CHECKING: + pos: llama_pos + _fields_ = [("pos", llama_pos)] @@ -1394,6 +1503,16 @@ class llama_kv_cache_view_cell(ctypes.Structure): # llama_seq_id * cells_sequences; # }; class llama_kv_cache_view(ctypes.Structure): + if TYPE_CHECKING: + n_cells: int + n_max_seq: int + token_count: int + used_cells: int + max_contiguous: int + max_contiguous_idx: int + cells: CtypesArray[llama_kv_cache_view_cell] + cells_sequences: CtypesArray[llama_seq_id] + _fields_ = [ ("n_cells", ctypes.c_int32), ("n_max_seq", ctypes.c_int32), @@ -1783,7 +1902,8 @@ def llama_state_load_file( n_token_capacity: Union[ctypes.c_size_t, int], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], /, -) -> bool: ... +) -> bool: + ... # LLAMA_API DEPRECATED(bool llama_load_session_file( @@ -1811,7 +1931,8 @@ def llama_load_session_file( n_token_capacity: Union[ctypes.c_size_t, int], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], /, -) -> int: ... +) -> int: + ... # LLAMA_API bool llama_state_save_file( @@ -1835,7 +1956,8 @@ def llama_state_save_file( tokens: CtypesArray[llama_token], n_token_count: Union[ctypes.c_size_t, int], /, -) -> bool: ... +) -> bool: + ... # LLAMA_API DEPRECATED(bool llama_save_session_file( @@ -1860,7 +1982,8 @@ def llama_save_session_file( tokens: CtypesArray[llama_token], n_token_count: Union[ctypes.c_size_t, int], /, -) -> int: ... +) -> int: + ... # // Get the exact size needed to copy the KV cache of a single sequence @@ -2233,7 +2356,8 @@ def llama_get_embeddings_seq( ) def llama_token_get_text( model: llama_model_p, token: Union[llama_token, int], / -) -> bytes: ... +) -> bytes: + ... # LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); @@ -2242,7 +2366,8 @@ def llama_token_get_text( ) def llama_token_get_score( model: llama_model_p, token: Union[llama_token, int], / -) -> float: ... +) -> float: + ... # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); @@ -2251,7 +2376,8 @@ def llama_token_get_score( ) def llama_token_get_type( model: llama_model_p, token: Union[llama_token, int], / -) -> int: ... +) -> int: + ... # // Special tokens @@ -2318,17 +2444,20 @@ def llama_token_prefix(model: llama_model_p) -> int: # LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle @ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token) -def llama_token_middle(model: llama_model_p, /) -> int: ... +def llama_token_middle(model: llama_model_p, /) -> int: + ... # LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix @ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token) -def llama_token_suffix(model: llama_model_p, /) -> int: ... +def llama_token_suffix(model: llama_model_p, /) -> int: + ... # LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle @ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token) -def llama_token_eot(model: llama_model_p, /) -> int: ... +def llama_token_eot(model: llama_model_p, /) -> int: + ... # // @@ -2459,7 +2588,8 @@ def llama_chat_apply_template( chat: CtypesArray[llama_chat_message], n_msg: int, /, -) -> int: ... +) -> int: + ... # // @@ -2989,6 +3119,12 @@ def llama_grammar_accept_token( # bool eob; // Callback should set this to true when a beam is at end-of-beam. # }; class llama_beam_view(ctypes.Structure): + if TYPE_CHECKING: + tokens: CtypesArray[llama_token] + n_tokens: int + p: float + eob: bool + _fields_ = [ ("tokens", llama_token_p), ("n_tokens", ctypes.c_size_t), @@ -3008,6 +3144,12 @@ class llama_beam_view(ctypes.Structure): # bool last_call; // True iff this is the last callback invocation. # }; class llama_beams_state(ctypes.Structure): + if TYPE_CHECKING: + beam_views: CtypesArray[llama_beam_view] + n_beams: int + common_prefix_length: int + last_call: bool + _fields_ = [ ("beam_views", ctypes.POINTER(llama_beam_view)), ("n_beams", ctypes.c_size_t), @@ -3060,7 +3202,8 @@ def llama_beam_search( n_past: Union[ctypes.c_int, int], n_predict: Union[ctypes.c_int, int], /, -): ... +): + ... # /// @details Build a split GGUF final path for this chunk. @@ -3179,4 +3322,5 @@ def llama_log_set( [ctypes.c_void_p, llama_context_p_ctypes], None, ) -def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ... +def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): + ...