Change pointer to lower overhead byref

This commit is contained in:
Andrei Betlen 2023-05-07 20:01:34 -04:00
parent 14da46f16e
commit e72f58614b

View file

@ -295,47 +295,47 @@ class Llama:
ctx=self.ctx,
last_tokens_data=last_n_tokens_data,
last_tokens_size=last_n_tokens_size,
candidates=llama_cpp.ctypes.pointer(candidates),
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
penalty=repeat_penalty,
)
if float(temp.value) == 0.0:
return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
)
else:
llama_cpp.llama_sample_top_k(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
k=top_k,
min_keep=llama_cpp.c_size_t(1),
)
llama_cpp.llama_sample_tail_free(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
z=llama_cpp.c_float(1.0),
min_keep=llama_cpp.c_size_t(1),
)
llama_cpp.llama_sample_typical(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
p=llama_cpp.c_float(1.0),
min_keep=llama_cpp.c_size_t(1),
)
llama_cpp.llama_sample_top_p(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
p=top_p,
min_keep=llama_cpp.c_size_t(1),
)
llama_cpp.llama_sample_temperature(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
return llama_cpp.llama_sample_token(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
)
def sample(