diff --git a/tests/test_llama.py b/tests/test_llama.py index 9ce2a2a..c98148e 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -96,6 +96,56 @@ def mock_llama(monkeypatch): monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) + def mock_kv_cache_clear(ctx: llama_cpp.llama_context_p): + # Test some basic invariants of this mocking technique + assert ctx == llama._ctx.ctx, "context does not match mock_llama" + return + + def mock_kv_cache_seq_rm( + ctx: llama_cpp.llama_context_p, + seq_id: llama_cpp.llama_seq_id, + pos0: llama_cpp.llama_pos, + pos1: llama_cpp.llama_pos, + ): + # Test some basic invariants of this mocking technique + assert ctx == llama._ctx.ctx, "context does not match mock_llama" + return + + def mock_kv_cache_seq_cp( + ctx: llama_cpp.llama_context_p, + seq_id_src: llama_cpp.llama_seq_id, + seq_id_dst: llama_cpp.llama_seq_id, + pos0: llama_cpp.llama_pos, + pos1: llama_cpp.llama_pos, + ): + # Test some basic invariants of this mocking technique + assert ctx == llama._ctx.ctx, "context does not match mock_llama" + return + + def mock_kv_cache_seq_keep( + ctx: llama_cpp.llama_context_p, + seq_id: llama_cpp.llama_seq_id, + ): + # Test some basic invariants of this mocking technique + assert ctx == llama._ctx.ctx, "context does not match mock_llama" + return + + def mock_kv_cache_seq_shift( + ctx: llama_cpp.llama_context_p, + seq_id: llama_cpp.llama_seq_id, + pos0: llama_cpp.llama_pos, + pos1: llama_cpp.llama_pos, + ): + # Test some basic invariants of this mocking technique + assert ctx == llama._ctx.ctx, "context does not match mock_llama" + return + + monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_clear", mock_kv_cache_clear) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_rm", mock_kv_cache_seq_rm) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_cp", mock_kv_cache_seq_cp) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_keep", mock_kv_cache_seq_keep) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_kv_cache_seq_shift", mock_kv_cache_seq_shift) + return setup_mock