logger.py 2.6 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15 16 17 18
import logging
import os
import sys

L
LielinJiang 已提交
19
from paddle.distributed import ParallelEnv
L
LielinJiang 已提交
20

R
ruri 已提交
21
# Names of loggers that setup_logger() has already configured; consulted so
# that repeated setup/get calls do not attach duplicate handlers.
logger_initialized = list()
L
LielinJiang 已提交
22

L
LielinJiang 已提交
23 24 25

def setup_logger(output=None, name="ppgan"):
    """
    Initialize the ppgan logger and set its verbosity level to "INFO".

    Only the first call for a given ``name`` configures the logger; later
    calls return the same logger without adding handlers again.

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
            Workers with a non-zero local rank append ``.rank<local_rank>``
            to the file name so each worker writes its own file.
        name (str): the root module name of this logger

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    # Already configured by an earlier call: avoid attaching duplicate handlers.
    if name in logger_initialized:
        return logger

    logger.setLevel(logging.INFO)
    # Do not forward records to ancestor (e.g. root) handlers, which would
    # print every message a second time.
    logger.propagate = False

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
        datefmt="%m/%d %H:%M:%S")

    # stdout logging: local rank 0 only
    local_rank = ParallelEnv().local_rank
    if local_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith((".txt", ".log")):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if local_rank > 0:
            filename = filename + ".rank{}".format(local_rank)

        # A bare file name (e.g. "train.log") has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError; only create real dirs.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        fh = logging.FileHandler(filename, mode='a')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    logger_initialized.append(name)
    return logger
L
LielinJiang 已提交
74 75


L
LielinJiang 已提交
76
def get_logger(name='ppgan'):
    """Return the logger called ``name``, configuring it on first use.

    A name that has not been set up yet is handed to
    ``setup_logger(name=name)``, i.e. configured without a log file.
    An already-configured logger is returned unchanged.
    """
    if name not in logger_initialized:
        return setup_logger(name=name)
    return logging.getLogger(name)