Enable finish reason tests

This commit is contained in:
Andrei Betlen 2023-10-19 02:56:45 -04:00
parent 09a8406c83
commit ef03d77b59

View file

@ -69,7 +69,7 @@ def test_llama_patch(monkeypatch):
n = 0 # reset n = 0 # reset
chunks = list(llama.create_completion(text, max_tokens=20, stream=True)) chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
# assert chunks[-1]["choices"][0]["finish_reason"] == "stop" assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
## Test basic completion until stop sequence ## Test basic completion until stop sequence
n = 0 # reset n = 0 # reset
@ -83,19 +83,19 @@ def test_llama_patch(monkeypatch):
assert ( assert (
"".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the " "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
) )
# assert chunks[-1]["choices"][0]["finish_reason"] == "stop" assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
## Test basic completion until length ## Test basic completion until length
n = 0 # reset n = 0 # reset
completion = llama.create_completion(text, max_tokens=2) completion = llama.create_completion(text, max_tokens=2)
assert completion["choices"][0]["text"] == " jumps" assert completion["choices"][0]["text"] == " jumps"
# assert completion["choices"][0]["finish_reason"] == "length" assert completion["choices"][0]["finish_reason"] == "length"
## Test streaming completion until length ## Test streaming completion until length
n = 0 # reset n = 0 # reset
chunks = list(llama.create_completion(text, max_tokens=2, stream=True)) chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps" assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
# assert chunks[-1]["choices"][0]["finish_reason"] == "length" assert chunks[-1]["choices"][0]["finish_reason"] == "length"
def test_llama_pickle(): def test_llama_pickle():