Compare commits

...

8 commits

Author SHA1 Message Date
baalajimaestro 4cb67f59d8
Merge https://github.com/abetlen/llama-cpp-python 2024-03-17 10:24:32 +05:30
Andrei Betlen 6eb25231e4 feat: Update llama.cpp 2024-03-15 12:58:45 -04:00
Andrei Betlen 20e6815252 fix: json mode 2024-03-15 12:58:34 -04:00
Andrei Betlen 1a9b8af2dd feat: Update llama.cpp 2024-03-14 11:46:48 -04:00
Andrei Betlen 4084aabe86 fix: set default pooling type to unspecified 2024-03-14 10:04:57 -04:00
Andrei Betlen d318cc8b83 fix: Set default pooling_type to mean, check for null pointer. 2024-03-14 09:17:41 -04:00
Andrei Betlen dd0ee56217 feat: Update llama.cpp 2024-03-13 15:57:35 -04:00
Andrei Betlen 08e910f7a7 feat: Update llama.cpp 2024-03-10 23:45:05 -04:00
4 changed files with 147 additions and 122 deletions

View file

@ -79,6 +79,7 @@ class Llama:
n_threads: Optional[int] = None,
n_threads_batch: Optional[int] = None,
rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
rope_freq_base: float = 0.0,
rope_freq_scale: float = 0.0,
yarn_ext_factor: float = -1.0,
@ -151,6 +152,7 @@ class Llama:
n_threads: Number of threads to use for generation
n_threads_batch: Number of threads to use for batch processing
rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
pooling_type: Pooling type, from `enum llama_pooling_type`.
rope_freq_base: RoPE base frequency, 0 = from model
rope_freq_scale: RoPE frequency scaling factor, 0 = from model
yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@ -271,6 +273,7 @@ class Llama:
if rope_scaling_type is not None
else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
)
self.context_params.pooling_type = pooling_type
self.context_params.rope_freq_base = (
rope_freq_base if rope_freq_base != 0.0 else 0
)
@ -814,9 +817,12 @@ class Llama:
# store embeddings
for i in range(n_seq):
embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
ptr = llama_cpp.llama_get_embeddings_seq(
self._ctx.ctx, i
)[:n_embd]
)
if not ptr:
raise RuntimeError("Failed to get embeddings from sequence pooling type is not set")
embedding: List[float] = ptr[:n_embd]
if normalize:
norm = float(np.linalg.norm(embedding))
embedding = [v / norm for v in embedding]

View file

@ -339,16 +339,7 @@ def chat_formatter_to_chat_completion_handler(
stop = stop + rstop
if response_format is not None and response_format["type"] == "json_object":
try:
# create grammar from json schema
if "schema" in response_format:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(response_format["schema"]), verbose=llama.verbose
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF, verbose=llama.verbose
)
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
completion_or_chunks = llama.create_completion(
prompt=prompt,
@ -606,6 +597,35 @@ def _format_chatglm3(
ret += role
return ret
def _grammar_for_json(verbose:bool=False):
return llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF, verbose=verbose)
def _grammar_for_json_schema(
schema: str,
verbose: bool = False,
fallback_to_json: bool = True
):
try:
return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose)
except Exception as e:
if fallback_to_json:
return _grammar_for_json(verbose=verbose)
else:
raise e
def _grammar_for_response_format(
response_format: llama_types.ChatCompletionRequestResponseFormat,
verbose: bool = False
):
if response_format["type"] != "json_object":
return None
if "schema" in response_format:
return _grammar_for_json_schema(
json.dumps(response_format["schema"]), verbose=verbose
)
else:
return _grammar_for_json(verbose=verbose)
### Chat Formats ###
@ -1994,16 +2014,7 @@ class Llava15ChatHandler:
prompt = llama.input_ids[: llama.n_tokens].tolist()
if response_format is not None and response_format["type"] == "json_object":
try:
# create grammar from json schema
if "schema" in response_format:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(response_format["schema"])
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF
)
grammar = _grammar_for_response_format(response_format)
return _convert_completion_to_chat(
llama.create_completion(
@ -2159,26 +2170,10 @@ def chatml_function_calling(
tool_calls=None,
add_generation_prompt=True,
)
if response_format is not None and response_format["type"] == "json_object":
try:
grammar = (
llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(response_format["schema"])
)
if "schema" in response_format
else None
)
except Exception as e:
if llama.verbose:
print(
"Failed to parse response format as JSON schema, falling back to default grammar"
)
print(e)
grammar = (
llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
if grammar is None
else grammar
)
grammar = _grammar_for_response_format(response_format)
return _convert_completion_to_chat(
llama.create_completion(
prompt=prompt,

View file

@ -198,13 +198,15 @@ llama_seq_id = ctypes.c_int32
# enum llama_vocab_type {
# LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
# LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
# LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
# LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
# LLAMA_VOCAB_TYPE_SPM = 1, // SentencePiece
# LLAMA_VOCAB_TYPE_BPE = 2, // Byte Pair Encoding
# LLAMA_VOCAB_TYPE_WPM = 3, // WordPiece
# };
LLAMA_VOCAB_TYPE_SPM = 0
LLAMA_VOCAB_TYPE_BPE = 1
LLAMA_VOCAB_TYPE_WPM = 2
LLAMA_VOCAB_TYPE_NONE = 0
LLAMA_VOCAB_TYPE_SPM = 1
LLAMA_VOCAB_TYPE_BPE = 2
LLAMA_VOCAB_TYPE_WPM = 3
# // note: these values should be synchronized with ggml_rope
@ -548,8 +550,9 @@ class llama_model_params(ctypes.Structure):
# struct llama_context_params {
# uint32_t seed; // RNG seed, -1 for random
# uint32_t n_ctx; // text context, 0 = from model
# uint32_t n_batch; // prompt processing maximum batch size
# uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
# uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
# uint32_t n_ubatch; // physical maximum batch size
# uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
# uint32_t n_threads; // number of threads to use for generation
# uint32_t n_threads_batch; // number of threads to use for batch processing
@ -590,8 +593,9 @@ class llama_context_params(ctypes.Structure):
Attributes:
seed (int): RNG seed, -1 for random
n_ctx (int): text context, 0 = from model
n_batch (int): prompt processing maximum batch size
n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models)
n_batch (int): logical maximum batch size that can be submitted to llama_decode
n_ubatch (int): physical maximum batch size
n_seq_max (int): max number of sequences (i.e. distinct states for recurrent models)
n_threads (int): number of threads to use for generation
n_threads_batch (int): number of threads to use for batch processing
rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
@ -619,7 +623,8 @@ class llama_context_params(ctypes.Structure):
("seed", ctypes.c_uint32),
("n_ctx", ctypes.c_uint32),
("n_batch", ctypes.c_uint32),
("n_parallel", ctypes.c_uint32),
("n_ubatch", ctypes.c_uint32),
("n_seq_max", ctypes.c_uint32),
("n_threads", ctypes.c_uint32),
("n_threads_batch", ctypes.c_uint32),
("rope_scaling_type", ctypes.c_int),
@ -667,7 +672,7 @@ It might not exist for progress report where '.' is output repeatedly."""
# bool allow_requantize; // allow quantizing non-f32/f16 tensors
# bool quantize_output_tensor; // quantize output.weight
# bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
# bool pure; // disable k-quant mixtures and quantize all tensors to the same type
# bool pure; // quantize all tensors to the default type
# void * imatrix; // pointer to importance matrix data
# } llama_model_quantize_params;
class llama_model_quantize_params(ctypes.Structure):
@ -679,7 +684,7 @@ class llama_model_quantize_params(ctypes.Structure):
allow_requantize (bool): allow quantizing non-f32/f16 tensors
quantize_output_tensor (bool): quantize output.weight
only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
pure (bool): disable k-quant mixtures and quantize all tensors to the same type
pure (bool): quantize all tensors to the default type
imatrix (ctypes.ctypes.c_void_p): pointer to importance matrix data
"""
@ -860,8 +865,7 @@ GGML_NUMA_STRATEGY_COUNT = 5
[ctypes.c_int],
None,
)
def llama_numa_init(numa: int, /):
...
def llama_numa_init(numa: int, /): ...
# // Call once at the end of the program - currently only used for MPI
@ -886,8 +890,7 @@ def llama_backend_free():
)
def llama_load_model_from_file(
path_model: bytes, params: llama_model_params, /
) -> Optional[llama_model_p]:
...
) -> Optional[llama_model_p]: ...
# LLAMA_API void llama_free_model(struct llama_model * model);
@ -896,8 +899,7 @@ def llama_load_model_from_file(
[llama_model_p_ctypes],
None,
)
def llama_free_model(model: llama_model_p, /):
...
def llama_free_model(model: llama_model_p, /): ...
# LLAMA_API struct llama_context * llama_new_context_with_model(
@ -910,8 +912,7 @@ def llama_free_model(model: llama_model_p, /):
)
def llama_new_context_with_model(
model: llama_model_p, params: llama_context_params, /
) -> Optional[llama_context_p]:
...
) -> Optional[llama_context_p]: ...
# // Frees all allocated memory
@ -932,80 +933,77 @@ def llama_free(ctx: llama_context_p, /):
[],
ctypes.c_int64,
)
def llama_time_us() -> int:
...
def llama_time_us() -> int: ...
# LLAMA_API size_t llama_max_devices(void);
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int:
...
def llama_max_devices() -> int: ...
# LLAMA_API bool llama_supports_mmap (void);
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
def llama_supports_mmap() -> bool:
...
def llama_supports_mmap() -> bool: ...
# LLAMA_API bool llama_supports_mlock (void);
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
def llama_supports_mlock() -> bool:
...
def llama_supports_mlock() -> bool: ...
# LLAMA_API bool llama_supports_gpu_offload(void);
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
def llama_supports_gpu_offload() -> bool:
...
def llama_supports_gpu_offload() -> bool: ...
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
...
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ...
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ctx(ctx: llama_context_p, /) -> int:
...
def llama_n_ctx(ctx: llama_context_p, /) -> int: ...
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_batch(ctx: llama_context_p, /) -> int:
...
def llama_n_batch(ctx: llama_context_p, /) -> int: ...
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_seq_max(ctx: llama_context_p, /) -> int: ...
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_vocab_type(model: llama_model_p, /) -> int:
...
def llama_vocab_type(model: llama_model_p, /) -> int: ...
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_rope_type(model: llama_model_p, /) -> int:
...
def llama_rope_type(model: llama_model_p, /) -> int: ...
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_vocab(model: llama_model_p, /) -> int:
...
def llama_n_vocab(model: llama_model_p, /) -> int: ...
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_ctx_train(model: llama_model_p, /) -> int:
...
def llama_n_ctx_train(model: llama_model_p, /) -> int: ...
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_embd(model: llama_model_p, /) -> int:
...
def llama_n_embd(model: llama_model_p, /) -> int: ...
# // Get the model's RoPE frequency scaling factor
@ -1192,8 +1190,7 @@ def llama_model_apply_lora_from_file(
path_base_model: Union[ctypes.c_char_p, bytes, None],
n_threads: Union[ctypes.c_int32, int],
/,
) -> int:
...
) -> int: ...
# //
@ -1219,7 +1216,7 @@ class llama_kv_cache_view_cell(ctypes.Structure):
# // Maximum number of sequences that can exist in a cell. It's not an error
# // if there are more sequences in a cell than this value, however they will
# // not be visible in the view cells_sequences.
# int32_t n_max_seq;
# int32_t n_seq_max;
# // Number of tokens in the cache. For example, if there are two populated
# // cells, the first with 1 sequence id in it and the second with 2 sequence
@ -1240,7 +1237,7 @@ class llama_kv_cache_view_cell(ctypes.Structure):
# struct llama_kv_cache_view_cell * cells;
# // The sequences for each cell. There will be n_max_seq items per cell.
# // The sequences for each cell. There will be n_seq_max items per cell.
# llama_seq_id * cells_sequences;
# };
class llama_kv_cache_view(ctypes.Structure):
@ -1260,14 +1257,14 @@ llama_kv_cache_view_p = ctypes.POINTER(llama_kv_cache_view)
# // Create an empty KV cache view. (use only for debugging purposes)
# LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
# LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
@ctypes_function(
"llama_kv_cache_view_init",
[llama_context_p_ctypes, ctypes.c_int32],
llama_kv_cache_view,
)
def llama_kv_cache_view_init(
ctx: llama_context_p, n_max_seq: Union[ctypes.c_int32, int], /
ctx: llama_context_p, n_seq_max: Union[ctypes.c_int32, int], /
) -> llama_kv_cache_view:
"""Create an empty KV cache view. (use only for debugging purposes)"""
...
@ -1582,8 +1579,7 @@ def llama_load_session_file(
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> int:
...
) -> int: ...
# LLAMA_API bool llama_save_session_file(
@ -1607,8 +1603,7 @@ def llama_save_session_file(
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/,
) -> int:
...
) -> int: ...
# //
@ -1728,6 +1723,17 @@ def llama_set_n_threads(
"""
...
# // Set whether to use causal attention or not
# // If set to true, the model will only attend to the past tokens
# LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
@ctypes_function("llama_set_causal_attn", [llama_context_p_ctypes, ctypes.c_bool], None)
def llama_set_causal_attn(ctx: llama_context_p, causal_attn: bool, /):
"""Set whether to use causal attention or not
If set to true, the model will only attend to the past tokens"""
...
# // Set abort callback
# LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
@ctypes_function(
@ -1745,6 +1751,18 @@ def llama_set_abort_callback(
...
# // Wait until all computations are finished
# // This is automatically done when using one of the functions below to obtain the computation results
# // and is not necessary to call it explicitly in most cases
# LLAMA_API void llama_synchronize(struct llama_context * ctx);
@ctypes_function("llama_synchronize", [llama_context_p_ctypes], None)
def llama_synchronize(ctx: llama_context_p, /):
"""Wait until all computations are finished
This is automatically done when using one of the functions below to obtain the computation results
and is not necessary to call it explicitly in most cases"""
...
# // Token logits obtained from the last call to llama_decode()
# // The logits for the last token are stored in the last row
# // Logits for which llama_batch.logits[i] == 0 are undefined
@ -1760,7 +1778,7 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
Logits for which llama_batch.logits[i] == 0 are undefined
Rows: n_tokens provided with llama_batch
Cols: n_vocab
Returns:
Pointer to the logits buffer of shape (n_tokens, n_vocab)"""
...
@ -1828,6 +1846,7 @@ def llama_get_embeddings_seq(
shape: [n_embd] (1-dimensional)"""
...
# //
# // Vocab
# //
@ -1839,8 +1858,7 @@ def llama_get_embeddings_seq(
)
def llama_token_get_text(
model: llama_model_p, token: Union[llama_token, int], /
) -> bytes:
...
) -> bytes: ...
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
@ -1849,8 +1867,7 @@ def llama_token_get_text(
)
def llama_token_get_score(
model: llama_model_p, token: Union[llama_token, int], /
) -> float:
...
) -> float: ...
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
@ -1859,8 +1876,7 @@ def llama_token_get_score(
)
def llama_token_get_type(
model: llama_model_p, token: Union[llama_token, int], /
) -> int:
...
) -> int: ...
# // Special tokens
@ -1913,20 +1929,17 @@ def llama_token_prefix(model: llama_model_p) -> int:
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
def llama_token_middle(model: llama_model_p, /) -> int:
...
def llama_token_middle(model: llama_model_p, /) -> int: ...
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
def llama_token_suffix(model: llama_model_p, /) -> int:
...
def llama_token_suffix(model: llama_model_p, /) -> int: ...
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
def llama_token_eot(model: llama_model_p, /) -> int:
...
def llama_token_eot(model: llama_model_p, /) -> int: ...
# //
@ -1936,7 +1949,7 @@ def llama_token_eot(model: llama_model_p, /) -> int:
# /// @details Convert the provided text into tokens.
# /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
# /// @return Returns the number of tokens on success, no more than n_max_tokens
# /// @return Returns the number of tokens on success, no more than n_tokens_max
# /// @return Returns a negative number on failure - the number of tokens that would have been returned
# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
# /// Does not insert a leading space.
@ -1945,7 +1958,7 @@ def llama_token_eot(model: llama_model_p, /) -> int:
# const char * text,
# int32_t text_len,
# llama_token * tokens,
# int32_t n_max_tokens,
# int32_t n_tokens_max,
# bool add_bos,
# bool special);
@ctypes_function(
@ -1966,12 +1979,26 @@ def llama_tokenize(
text: bytes,
text_len: Union[ctypes.c_int, int],
tokens: CtypesArray[llama_token],
n_max_tokens: Union[ctypes.c_int, int],
n_tokens_max: Union[ctypes.c_int, int],
add_bos: Union[ctypes.c_bool, bool],
special: Union[ctypes.c_bool, bool],
/,
) -> int:
"""Convert the provided text into tokens."""
"""Convert the provided text into tokens.
Args:
model: The model to use for tokenization.
text: The text to tokenize.
text_len: The length of the text.
tokens: The tokens pointer must be large enough to hold the resulting tokens.
n_max_tokens: The maximum number of tokens to return.
add_bos: Whether to add a beginning-of-sentence token.
special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
Does not insert a leading space.
Returns:
Returns the number of tokens on success, no more than n_tokens_max
Returns a negative number on failure - the number of tokens that would have been returned"""
...
@ -2043,8 +2070,7 @@ def llama_chat_apply_template(
chat: CtypesArray[llama_chat_message],
n_msg: int,
/,
) -> int:
...
) -> int: ...
# //
@ -2645,8 +2671,7 @@ def llama_beam_search(
n_past: Union[ctypes.c_int, int],
n_predict: Union[ctypes.c_int, int],
/,
):
...
): ...
# Performance information
@ -2723,5 +2748,4 @@ def llama_log_set(
[ctypes.c_void_p, llama_context_p_ctypes],
None,
)
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
...
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ...

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit c2101a2e909ac7c08976d414e64e96c90ee5fa9e
Subproject commit 4e9a7f7f7fb6acbddd1462909c8d696e38edbfcc