diff --git a/tests/test_llama.py b/tests/test_llama.py index 84fa31c..9ce2a2a 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -47,16 +47,22 @@ def test_llama_cpp_tokenization(): @pytest.fixture def mock_llama(monkeypatch): def setup_mock(llama: llama_cpp.Llama, output_text: str): + n_ctx = llama.n_ctx() n_vocab = llama.n_vocab() output_tokens = llama.tokenize( output_text.encode("utf-8"), add_bos=True, special=True ) + logits = (llama_cpp.c_float * (n_vocab * n_ctx))(-100.0) + for i in range(n_ctx): + output_idx = i + 1 # logits for first tokens predict second token + if output_idx < len(output_tokens): + logits[i * n_vocab + output_tokens[output_idx]] = 100.0 + else: + logits[i * n_vocab + llama.token_eos()] = 100.0 n = 0 last_n_tokens = 0 def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch): - nonlocal n - nonlocal last_n_tokens # Test some basic invariants of this mocking technique assert ctx == llama._ctx.ctx, "context does not match mock_llama" assert batch.n_tokens > 0, "no tokens in batch" @@ -70,26 +76,22 @@ def mock_llama(monkeypatch): batch.n_tokens - 1 ], "logits not allocated for last token" # Update the mock context state + nonlocal n + nonlocal last_n_tokens n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1 last_n_tokens = batch.n_tokens return 0 - def mock_get_logits(*args, **kwargs): - nonlocal n - nonlocal last_n_tokens + def mock_get_logits(ctx: llama_cpp.llama_context_p): + # Test some basic invariants of this mocking technique + assert ctx == llama._ctx.ctx, "context does not match mock_llama" assert n > 0, "mock_llama_decode not called" assert last_n_tokens > 0, "mock_llama_decode not called" - logits = (llama_cpp.c_float * (last_n_tokens * n_vocab))(-100.0) - for logits_idx, output_idx in enumerate( - range(n - last_n_tokens + 1, n + 1) - ): - if output_idx < len(output_tokens): - logits[ - logits_idx * last_n_tokens + output_tokens[output_idx] - ] = 100.0 - else: - logits[logits_idx * last_n_tokens + llama.token_eos()] = 100.0 - return logits + # Return view of logits for last_n_tokens + return (llama_cpp.c_float * (last_n_tokens * n_vocab)).from_address( + ctypes.addressof(logits) + + (n - last_n_tokens) * n_vocab * ctypes.sizeof(llama_cpp.c_float) + ) monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)