Update llama.cpp

This commit is contained in:
Andrei Betlen 2023-11-23 16:26:00 -05:00
parent 4474157949
commit 36048d46af
2 changed files with 120 additions and 13 deletions


@@ -273,11 +273,11 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
# } llama_batch;
class llama_batch(Structure):
    """Input data for llama_decode
    A llama_batch object can contain input about one or many sequences
    The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens

    Attributes:
        token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
        embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
@@ -890,11 +890,103 @@ _lib.llama_model_apply_lora_from_file.restype = c_int
# //
# // Information associated with an individual cell in the KV cache view.
# struct llama_kv_cache_view_cell {
#     // The position for this cell. Takes KV cache shifts into account.
#     // May be negative if the cell is not populated.
#     llama_pos pos;
# };
class llama_kv_cache_view_cell(Structure):
    _fields_ = [("pos", llama_pos)]


# // An updateable view of the KV cache.
# struct llama_kv_cache_view {
#     // Number of KV cache cells. This will be the same as the context size.
#     int32_t n_cells;
#     // Maximum number of sequences that can exist in a cell. It's not an error
#     // if there are more sequences in a cell than this value, however they will
#     // not be visible in the view cells_sequences.
#     int32_t n_max_seq;
#     // Number of tokens in the cache. For example, if there are two populated
#     // cells, the first with 1 sequence id in it and the second with 2 sequence
#     // ids then you'll have 3 tokens.
#     int32_t token_count;
#     // Number of populated cache cells.
#     int32_t used_cells;
#     // Maximum contiguous empty slots in the cache.
#     int32_t max_contiguous;
#     // Index to the start of the max_contiguous slot range. Can be negative
#     // when cache is full.
#     int32_t max_contiguous_idx;
#     // Information for an individual cell.
#     struct llama_kv_cache_view_cell * cells;
#     // The sequences for each cell. There will be n_max_seq items per cell.
#     llama_seq_id * cells_sequences;
# };
class llama_kv_cache_view(Structure):
    _fields_ = [
        ("n_cells", c_int32),
        ("n_max_seq", c_int32),
        ("token_count", c_int32),
        ("used_cells", c_int32),
        ("max_contiguous", c_int32),
        ("max_contiguous_idx", c_int32),
        ("cells", POINTER(llama_kv_cache_view_cell)),
        ("cells_sequences", POINTER(llama_seq_id)),
    ]
# // Create an empty KV cache view. (use only for debugging purposes)
# LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
def llama_kv_cache_view_init(
    ctx: llama_context_p, n_max_seq: Union[c_int32, int]
) -> llama_kv_cache_view:
    """Create an empty KV cache view. (use only for debugging purposes)"""
    return _lib.llama_kv_cache_view_init(ctx, n_max_seq)


_lib.llama_kv_cache_view_init.argtypes = [llama_context_p, c_int32]
_lib.llama_kv_cache_view_init.restype = llama_kv_cache_view


# // Free a KV cache view. (use only for debugging purposes)
# LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
def llama_kv_cache_view_free(view: llama_kv_cache_view):
    """Free a KV cache view. (use only for debugging purposes)"""
    return _lib.llama_kv_cache_view_free(view)


_lib.llama_kv_cache_view_free.argtypes = [llama_kv_cache_view]
_lib.llama_kv_cache_view_free.restype = None


# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
def llama_kv_cache_view_update(ctx: llama_context_p, view: llama_kv_cache_view):
    """Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)"""
    return _lib.llama_kv_cache_view_update(ctx, view)


_lib.llama_kv_cache_view_update.argtypes = [llama_context_p, llama_kv_cache_view]
_lib.llama_kv_cache_view_update.restype = None
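
A rough usage sketch of these debugging helpers (not part of this commit; `ctx` is assumed to be an already-initialized llama_context_p, and the field accesses follow the struct layout above):

# Hypothetical debugging snippet built on the bindings above.
view = llama_kv_cache_view_init(ctx, n_max_seq=4)  # track up to 4 sequence ids per cell
llama_kv_cache_view_update(ctx, view)              # refresh the snapshot from the current cache state
print(view.used_cells, "used cells,", view.token_count, "tokens")
for i in range(view.n_cells):
    if view.cells[i].pos < 0:
        continue  # a negative pos marks an unpopulated cell
    seq_ids = [view.cells_sequences[i * view.n_max_seq + j] for j in range(view.n_max_seq)]
    print("cell", i, "pos", view.cells[i].pos, "seq ids", seq_ids)
llama_kv_cache_view_free(view)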
# // Returns the number of tokens in the KV cache (slow, use only for debug)
# // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
# LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
    """Returns the number of tokens in the KV cache (slow, use only for debug)
    If a KV cell has multiple sequences assigned to it, it will be counted multiple times
    """
    return _lib.llama_get_kv_cache_token_count(ctx)
@@ -902,6 +994,17 @@ _lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_token_count.restype = c_int
# // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
# LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
def llama_get_kv_cache_used_cells(ctx: llama_context_p) -> int:
    """Returns the number of used KV cells (i.e. have at least one sequence assigned to them)"""
    return _lib.llama_get_kv_cache_used_cells(ctx)


_lib.llama_get_kv_cache_used_cells.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_used_cells.restype = c_int
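
For illustration, the two counters differ whenever cells are shared between sequences: a cell holding two sequence ids contributes one used cell but two tokens. A sketch, assuming `ctx` holds a populated cache:

n_tokens = llama_get_kv_cache_token_count(ctx)  # counts each (cell, sequence) assignment
n_used = llama_get_kv_cache_used_cells(ctx)     # counts each populated cell once
print(f"{n_tokens} tokens across {n_used} used cells")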
# // Clear the KV cache
# LLAMA_API void llama_kv_cache_clear(
#         struct llama_context * ctx);
@@ -1205,8 +1308,9 @@ def llama_batch_get_one(
    seq_id: llama_seq_id,
) -> llama_batch:
    """Return batch for single sequence of tokens starting at pos_0
    NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
    """
    return _lib.llama_batch_get_one(tokens, n_tokens, pos_0, seq_id)
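A minimal usage sketch (the token ids are hypothetical placeholders, and `llama_decode` is bound elsewhere in this module):

# Wrap an already-tokenized prompt in a single-sequence batch and decode it.
prompt_tokens = [1, 15043, 3186]  # hypothetical token ids
tokens = (llama_token * len(prompt_tokens))(*prompt_tokens)
batch = llama_batch_get_one(tokens, len(prompt_tokens), llama_pos(0), llama_seq_id(0))
if llama_decode(ctx, batch) != 0:
    raise RuntimeError("llama_decode failed")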
@@ -1290,7 +1394,8 @@ def llama_set_n_threads(
):
    """Set the number of threads used for decoding
    n_threads is the number of threads used for generation (single token)
    n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
    """
    return _lib.llama_set_n_threads(ctx, n_threads, n_threads_batch)
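For example (thread counts are arbitrary), batch/prompt processing is often given more threads than single-token generation:

llama_set_n_threads(ctx, 4, 8)  # sketch: 4 generation threads, 8 batch-processing threads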
@@ -1540,7 +1645,8 @@ def llama_token_to_piece(
    """Token Id -> Piece.
    Uses the vocabulary in the provided context.
    Does not write null terminator to the buffer.
    User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
    """
    return _lib.llama_token_to_piece(model, token, buf, length)
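A hedged sketch of decoding a single token id with this binding (the buffer size is an arbitrary guess, and `model` and `token` are assumed to come from surrounding code):

import ctypes
buf = ctypes.create_string_buffer(32)            # scratch buffer; no null terminator is written
n = llama_token_to_piece(model, token, buf, 32)  # returns the number of bytes written
piece = buf.raw[:n].decode("utf-8", errors="ignore")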
@@ -1626,7 +1732,8 @@ def llama_sample_repetition_penalties(
    penalty_present: Union[c_float, float],
):
    """Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
    Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
    """
    return _lib.llama_sample_repetition_penalties(
        ctx,
        candidates,
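
For intuition only, a pure-Python sketch of the per-logit arithmetic the docstring refers to (the CTRL-style repetition penalty with the negative-logit fix, followed by the OpenAI-style frequency and presence terms); this is not the ctypes call, and `count` is how often the candidate token appears in the recent history:

# Illustrative sketch of the penalty math applied to one candidate logit.
def penalize_logit(logit, count, penalty_repeat, penalty_freq, penalty_present):
    if count > 0:
        # Negative-logit fix: dividing a negative logit would raise it, so multiply instead.
        logit = logit * penalty_repeat if logit <= 0 else logit / penalty_repeat
        # Frequency penalty scales with the count; presence penalty applies once.
        logit -= count * penalty_freq + penalty_present
    return logit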

vendor/llama.cpp vendored

@@ -1 +1 @@
-Subproject commit 8e672efe632bb6a7333964a255c4b96f018b9a65
+Subproject commit 6b0a7420d03b9d13cb0e9439a01ce8476d8bf093