Unverified · Commit 2d3087c6 authored by Walter, committed by GitHub

Merge pull request #1798 from RainFrost1/develop

Fix the issue where logs are printed twice on the Paddle develop branch.
@@ -71,7 +71,7 @@ class Engine(object):
         self.output_dir = self.config['Global']['output_dir']
         log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                 f"{mode}.log")
-        init_logger(name='root', log_file=log_file)
+        init_logger(log_file=log_file)
         print_config(config)
         # init train_func and eval_func
@@ -92,7 +92,8 @@ class Engine(object):
             self.vdl_writer = LogWriter(logdir=vdl_writer_path)

         # set device
-        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu", "mlu"]
+        assert self.config["Global"][
+            "device"] in ["cpu", "gpu", "xpu", "npu", "mlu"]
         self.device = paddle.set_device(self.config["Global"]["device"])
         logger.info('train with paddle {} and device {}'.format(
             paddle.__version__, self.device))
@@ -107,9 +108,7 @@ class Engine(object):
         self.scale_loss = 1.0
         self.use_dynamic_loss_scaling = False
         if self.amp:
-            AMP_RELATED_FLAGS_SETTING = {
-                'FLAGS_max_inplace_grad_add': 8,
-            }
+            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
            if paddle.is_compiled_with_cuda():
                 AMP_RELATED_FLAGS_SETTING.update({
                     'FLAGS_cudnn_batchnorm_spatial_persistent': 1
......
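For context, the `AMP_RELATED_FLAGS_SETTING` dict assembled in the hunk above is ultimately handed to Paddle's global flag setter; that call sits outside the visible diff, so the snippet below is only a sketch of how such a dict is typically applied (using `paddle.set_flags`, which accepts a dict of `FLAGS_*` entries), not this PR's own code.

```python
import paddle

# Sketch only (assumption): applying the AMP-related flags built above.
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
    AMP_RELATED_FLAGS_SETTING.update({
        'FLAGS_cudnn_batchnorm_spatial_persistent': 1
    })
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
```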
@@ -71,7 +71,7 @@ def main(args):
     log_file = os.path.join(global_config['output_dir'],
                             config["Arch"]["name"], f"{mode}.log")
-    init_logger(name='root', log_file=log_file)
+    init_logger(log_file=log_file)
     print_config(config)
     if global_config.get("is_distributed", True):
......
@@ -22,7 +22,7 @@ import paddle.distributed as dist
 _logger = None


-def init_logger(name='root', log_file=None, log_level=logging.INFO):
+def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
     """Initialize and get a logger by name.
     If the logger has not been initialized, this method will initialize the
     logger by adding one or two handlers, otherwise the initialized logger will
@@ -59,6 +59,7 @@ def init_logger(name='root', log_file=None, log_level=logging.INFO):
         _logger.setLevel(log_level)
     else:
         _logger.setLevel(logging.ERROR)
+    _logger.propagate = False


 def log_at_trainer0(log):
......
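Why the rename plus `_logger.propagate = False` stops the double printing: `logging.getLogger('root')` returns a logger *named* 'root', which is a child of Python's real root logger, not the root logger itself. Records handled by its handler therefore also propagate up to the root logger, and if the framework (apparently Paddle's develop branch) has already attached a handler there, every message comes out twice. Using a dedicated `ppcls` logger with propagation disabled keeps PaddleClas messages on its own handlers, which is also why the call sites above can simply drop the `name='root'` argument. Below is a minimal standard-library sketch of the mechanism; the `[framework]` handler stands in for whatever Paddle configures and is an assumption, not taken from this PR.

```python
import logging
import sys

# Stand-in for a handler the framework may already have attached to the real
# root logger (assumption -- mimics the reported behaviour, not Paddle's code).
logging.basicConfig(stream=sys.stdout, format='[framework] %(message)s')

fmt = logging.Formatter('[ppcls]     %(message)s')

# Old setup: a handler on a logger *named* 'root'; records still propagate to
# the real root logger, so each message is handled twice.
old_logger = logging.getLogger('root')      # not logging.root, just a name
old_handler = logging.StreamHandler(sys.stdout)
old_handler.setFormatter(fmt)
old_logger.addHandler(old_handler)
old_logger.setLevel(logging.INFO)
old_logger.info('hello')                    # emitted twice

# New setup: a dedicated 'ppcls' logger with propagation disabled.
new_logger = logging.getLogger('ppcls')
new_handler = logging.StreamHandler(sys.stdout)
new_handler.setFormatter(fmt)
new_logger.addHandler(new_handler)
new_logger.setLevel(logging.INFO)
new_logger.propagate = False
new_logger.info('hello')                    # emitted once
```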