diff --git a/tests/test_llama.py b/tests/test_llama.py index fe2bd66..2bf38b3 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,4 +1,3 @@ -import pytest import llama_cpp MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin" @@ -15,15 +14,20 @@ def test_llama(): assert llama.detokenize(llama.tokenize(text)) == text -@pytest.mark.skip(reason="need to update sample mocking") +# @pytest.mark.skip(reason="need to update sample mocking") def test_llama_patch(monkeypatch): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) + n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx)) ## Set up mock function def mock_eval(*args, **kwargs): return 0 + + def mock_get_logits(*args, **kwargs): + return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)]) monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) output_text = " jumps over the lazy dog." output_tokens = llama.tokenize(output_text.encode("utf-8")) @@ -38,7 +42,7 @@ def test_llama_patch(monkeypatch): else: return token_eos - monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample) text = "The quick brown fox" @@ -97,15 +101,19 @@ def test_llama_pickle(): assert llama.detokenize(llama.tokenize(text)) == text -@pytest.mark.skip(reason="need to update sample mocking") def test_utf8(monkeypatch): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) + n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx)) ## Set up mock function def mock_eval(*args, **kwargs): return 0 + def mock_get_logits(*args, **kwargs): + return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)]) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) output_text = "😀" output_tokens = llama.tokenize(output_text.encode("utf-8")) @@ -120,7 +128,7 @@ def test_utf8(monkeypatch): else: return token_eos - monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample) ## Test basic completion with utf8 multibyte n = 0 # reset