From ce57920e608d075335dbd291476420f2abc491be Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jul 2023 14:45:18 -0400 Subject: [PATCH] Suppress llama.cpp output when loading model. --- llama_cpp/llama.py | 23 +++++++++++++++++++---- llama_cpp/utils.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) create mode 100644 llama_cpp/utils.py diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 2537af2..47f71e9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -27,6 +27,8 @@ from .llama_types import * import numpy as np import numpy.typing as npt +from .utils import suppress_stdout_stderr + class BaseLlamaCache(ABC): """Base cache class for a llama.cpp model.""" @@ -308,12 +310,25 @@ class Llama: if not os.path.exists(model_path): raise ValueError(f"Model path does not exist: {model_path}") - self.model = llama_cpp.llama_load_model_from_file( - self.model_path.encode("utf-8"), self.params - ) + if verbose: + self.model = llama_cpp.llama_load_model_from_file( + self.model_path.encode("utf-8"), self.params + ) + else: + with suppress_stdout_stderr(): + self.model = llama_cpp.llama_load_model_from_file( + self.model_path.encode("utf-8"), self.params + ) assert self.model is not None - self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params) + if verbose: + self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params) + else: + with suppress_stdout_stderr(): + print("here") + self.ctx = llama_cpp.llama_new_context_with_model( + self.model, self.params + ) assert self.ctx is not None diff --git a/llama_cpp/utils.py b/llama_cpp/utils.py new file mode 100644 index 0000000..c14f53f --- /dev/null +++ b/llama_cpp/utils.py @@ -0,0 +1,38 @@ +import os +import sys + + +class suppress_stdout_stderr(object): + # Oddly enough this works better than the contextlib version + def __enter__(self): + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close()