feat: add support for KV cache quantization options (#1307)

* add KV cache quantization options

https://github.com/abetlen/llama-cpp-python/discussions/1220
https://github.com/abetlen/llama-cpp-python/issues/1305

* Add ggml_type

* Use ggml_type instead of string for quantization

* Add server support

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
This commit is contained in:
Limour 2024-04-01 22:19:28 +08:00 committed by GitHub
parent aa9f1ae011
commit f165048a69
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 94 additions and 41 deletions

View file

@ -105,6 +105,9 @@ class Llama:
draft_model: Optional[LlamaDraftModel] = None,
# Tokenizer Override
tokenizer: Optional[BaseLlamaTokenizer] = None,
# KV cache quantization
type_k: Optional[int] = None,
type_v: Optional[int] = None,
# Misc
verbose: bool = True,
# Extra Params
@ -172,6 +175,8 @@ class Llama:
draft_model: Optional draft model to use for speculative decoding.
tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
verbose: Print verbose output to stderr.
type_k: KV cache data type for K (default: f16)
type_v: KV cache data type for V (default: f16)
Raises:
ValueError: If the model path does not exist.
@ -298,7 +303,11 @@ class Llama:
) # Must be set to True for speculative decoding
self.context_params.embeddings = embedding # TODO: Rename to embeddings
self.context_params.offload_kqv = offload_kqv
# KV cache quantization
if type_k is not None:
self.context_params.type_k = type_k
if type_v is not None:
self.context_params.type_v = type_v
# Sampling Params
self.last_n_tokens_size = last_n_tokens_size
@ -1724,6 +1733,7 @@ class Llama:
n_threads=self.context_params.n_threads,
n_threads_batch=self.context_params.n_threads_batch,
rope_scaling_type=self.context_params.rope_scaling_type,
pooling_type=self.context_params.pooling_type,
rope_freq_base=self.context_params.rope_freq_base,
rope_freq_scale=self.context_params.rope_freq_scale,
yarn_ext_factor=self.context_params.yarn_ext_factor,
@ -1733,6 +1743,7 @@ class Llama:
yarn_orig_ctx=self.context_params.yarn_orig_ctx,
logits_all=self.context_params.logits_all,
embedding=self.context_params.embeddings,
offload_kqv=self.context_params.offload_kqv,
# Sampling Params
last_n_tokens_size=self.last_n_tokens_size,
# LoRA Params
@ -1744,51 +1755,17 @@ class Llama:
# Chat Format Params
chat_format=self.chat_format,
chat_handler=self.chat_handler,
# Speculative Decidng
draft_model=self.draft_model,
# KV cache quantization
type_k=self.context_params.type_k,
type_v=self.context_params.type_v,
# Misc
verbose=self.verbose,
)
def __setstate__(self, state):
self.__init__(
model_path=state["model_path"],
# Model Params
n_gpu_layers=state["n_gpu_layers"],
split_mode=state["split_mode"],
main_gpu=state["main_gpu"],
tensor_split=state["tensor_split"],
vocab_only=state["vocab_only"],
use_mmap=state["use_mmap"],
use_mlock=state["use_mlock"],
kv_overrides=state["kv_overrides"],
# Context Params
seed=state["seed"],
n_ctx=state["n_ctx"],
n_batch=state["n_batch"],
n_threads=state["n_threads"],
n_threads_batch=state["n_threads_batch"],
rope_freq_base=state["rope_freq_base"],
rope_freq_scale=state["rope_freq_scale"],
rope_scaling_type=state["rope_scaling_type"],
yarn_ext_factor=state["yarn_ext_factor"],
yarn_attn_factor=state["yarn_attn_factor"],
yarn_beta_fast=state["yarn_beta_fast"],
yarn_beta_slow=state["yarn_beta_slow"],
yarn_orig_ctx=state["yarn_orig_ctx"],
logits_all=state["logits_all"],
embedding=state["embedding"],
# Sampling Params
last_n_tokens_size=state["last_n_tokens_size"],
# LoRA Params
lora_base=state["lora_base"],
lora_path=state["lora_path"],
# Backend Params
numa=state["numa"],
# Chat Format Params
chat_format=state["chat_format"],
chat_handler=state["chat_handler"],
# Misc
verbose=state["verbose"],
)
self.__init__(**state)
def save_state(self) -> LlamaState:
assert self._ctx.ctx is not None

View file

@ -141,6 +141,70 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa
byref = ctypes.byref # type: ignore
# from ggml.h
# // NOTE: always add types at the end of the enum to keep backward compatibility
# enum ggml_type {
# GGML_TYPE_F32 = 0,
# GGML_TYPE_F16 = 1,
# GGML_TYPE_Q4_0 = 2,
# GGML_TYPE_Q4_1 = 3,
# // GGML_TYPE_Q4_2 = 4, support has been removed
# // GGML_TYPE_Q4_3 = 5, support has been removed
# GGML_TYPE_Q5_0 = 6,
# GGML_TYPE_Q5_1 = 7,
# GGML_TYPE_Q8_0 = 8,
# GGML_TYPE_Q8_1 = 9,
# GGML_TYPE_Q2_K = 10,
# GGML_TYPE_Q3_K = 11,
# GGML_TYPE_Q4_K = 12,
# GGML_TYPE_Q5_K = 13,
# GGML_TYPE_Q6_K = 14,
# GGML_TYPE_Q8_K = 15,
# GGML_TYPE_IQ2_XXS = 16,
# GGML_TYPE_IQ2_XS = 17,
# GGML_TYPE_IQ3_XXS = 18,
# GGML_TYPE_IQ1_S = 19,
# GGML_TYPE_IQ4_NL = 20,
# GGML_TYPE_IQ3_S = 21,
# GGML_TYPE_IQ2_S = 22,
# GGML_TYPE_IQ4_XS = 23,
# GGML_TYPE_I8 = 24,
# GGML_TYPE_I16 = 25,
# GGML_TYPE_I32 = 26,
# GGML_TYPE_I64 = 27,
# GGML_TYPE_F64 = 28,
# GGML_TYPE_IQ1_M = 29,
# GGML_TYPE_COUNT,
# };
GGML_TYPE_F32 = 0
GGML_TYPE_F16 = 1
GGML_TYPE_Q4_0 = 2
GGML_TYPE_Q4_1 = 3
GGML_TYPE_Q5_0 = 6
GGML_TYPE_Q5_1 = 7
GGML_TYPE_Q8_0 = 8
GGML_TYPE_Q8_1 = 9
GGML_TYPE_Q2_K = 10
GGML_TYPE_Q3_K = 11
GGML_TYPE_Q4_K = 12
GGML_TYPE_Q5_K = 13
GGML_TYPE_Q6_K = 14
GGML_TYPE_Q8_K = 15
GGML_TYPE_IQ2_XXS = 16
GGML_TYPE_IQ2_XS = 17
GGML_TYPE_IQ3_XXS = 18
GGML_TYPE_IQ1_S = 19
GGML_TYPE_IQ4_NL = 20
GGML_TYPE_IQ3_S = 21
GGML_TYPE_IQ2_S = 22
GGML_TYPE_IQ4_XS = 23
GGML_TYPE_I8 = 24
GGML_TYPE_I16 = 25
GGML_TYPE_I32 = 26
GGML_TYPE_I64 = 27
GGML_TYPE_F64 = 28
GGML_TYPE_IQ1_M = 29
GGML_TYPE_COUNT = 30
# from ggml-backend.h
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);

View file

@ -175,6 +175,9 @@ class LlamaProxy:
chat_handler=chat_handler,
# Speculative Decoding
draft_model=draft_model,
# KV Cache Quantization
type_k=settings.type_k,
type_v=settings.type_v,
# Tokenizer
tokenizer=tokenizer,
# Misc

View file

@ -159,6 +159,15 @@ class ModelSettings(BaseSettings):
default=10,
description="Number of tokens to predict using the draft model.",
)
# KV Cache Quantization
type_k: Optional[int] = Field(
default=None,
description="Type of the key cache quantization.",
)
type_v: Optional[int] = Field(
default=None,
description="Type of the value cache quantization.",
)
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."