From b2cbdda8d43e8de21e4d5f43a29d3b5741c96850 Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Fri, 24 Apr 2020 12:26:29 +0000 Subject: [PATCH] add logger --- hapi/__init__.py | 17 ++++++--------- hapi/logger.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 11 deletions(-) create mode 100644 hapi/logger.py diff --git a/hapi/__init__.py b/hapi/__init__.py index eb3f008..3860aaf 100644 --- a/hapi/__init__.py +++ b/hapi/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from hapi import logger from hapi.configure import Config from hapi import callbacks from hapi import datasets @@ -22,16 +22,11 @@ from hapi import model from hapi import progressbar from hapi import text from hapi import vision +from hapi import loss + +logger.setup_logger() __all__ = [ - 'Config', - 'callbacks', - 'datasets', - 'distributed', - 'download', - 'metrics', - 'model', - 'progressbar', - 'text', - 'vision', + 'Config', 'callbacks', 'datasets', 'distributed', 'download', 'metrics', + 'model', 'progressbar', 'text', 'vision', 'loss' ] diff --git a/hapi/logger.py b/hapi/logger.py new file mode 100644 index 0000000..91e5ae8 --- /dev/null +++ b/hapi/logger.py @@ -0,0 +1,55 @@ +import os +import sys +import logging +import functools + +from paddle.fluid.dygraph.parallel import ParallelEnv + + +@functools.lru_cache() +def setup_logger(output=None, name="hapi", log_level=logging.INFO): + """ + Initialize logger of hapi and set its verbosity level to "INFO". + + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name (str): the root module name of this logger. Default: 'hapi' + + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger(name) + logger.propagate = False + + format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + logging.basicConfig(format=format_str, level=log_level) + + # stdout logging: only local rank==0 + local_rank = ParallelEnv().local_rank + if local_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + + ch.setFormatter(logging.Formatter(format_str)) + logger.addHandler(ch) + + # file logging if output is not None: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + if local_rank > 0: + filename = filename + ".rank{}".format(local_rank) + + if not os.path.exists(os.path.dirname(filename)): + os.makedirs(os.path.dirname(filename)) + + fh = logging.StreamHandler(filename) + fh.setLevel(logging.DEBUG) + fh.setFormatter(logging.Formatter(format_str)) + logger.addHandler(fh) + + return logger -- GitLab