import logging
import sys

import torch.distributed as dist


class LoggerFactory:
    """Factory for loggers that write formatted records to stdout."""

    @staticmethod
    def create_logger(name=None, level=logging.INFO):
        """Create (or reconfigure) a named logger that writes to stdout.

        Args:
            name (str): name of the logger; must not be None.
            level: logging level applied to the logger and its handler.

        Returns:
            logging.Logger: the configured logger.

        Raises:
            ValueError: if name is None.
        """
        if name is None:
            raise ValueError("name for logger cannot be None")

        formatter = logging.Formatter(
            "[%(asctime)s] [%(levelname)s] "
            "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")

        logger_ = logging.getLogger(name)
        logger_.setLevel(level)
        # Keep records from bubbling up to the root logger, which would
        # print every message twice when the root has its own handler.
        logger_.propagate = False
        # Fix: logging.getLogger returns the SAME object for the same name,
        # so the old code stacked a fresh StreamHandler on every call and
        # each record was then emitted once per call. Attach only once.
        if not logger_.handlers:
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(level)
            ch.setFormatter(formatter)
            logger_.addHandler(ch)
        return logger_


# Shared module-level logger used by log_dist below.
logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.INFO)


def log_dist(message, ranks=None, level=logging.INFO):
    """Log ``message`` when one of the following conditions is met:

    + torch.distributed is not initialized, or
    + the current rank is in ``ranks``, or ``ranks[0] == -1``
      (which means "log on every rank").

    Args:
        message (str): message to log.
        ranks (list): ranks that should log; None or empty means only the
            uninitialized (single-process) case logs.
        level (int): logging level.
    """
    should_log = not dist.is_initialized()
    ranks = ranks or []
    # Rank is -1 when torch.distributed has not been initialized.
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        final_message = "[Rank {}] {}".format(my_rank, message)
        logger.log(level, final_message)