Commit 23e2cec4 authored by dongshuilong

cherry-pick fix log twice bug

Parent 90cedca5
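For context: the duplicated log lines come from Python's `logging` propagation rules. Records logged on a named logger also bubble up to the root logger's handlers, so attaching handlers to `'root'` from more than one entry point makes every message print twice. A minimal sketch of the failure mode and the fix, in plain standard-library Python (independent of the PaddleClas code in the diff below):

```python
import logging
import sys

# A handler on the root logger, e.g. left over from an earlier
# init_logger(name='root', ...) call or a third-party library.
logging.basicConfig(stream=sys.stdout, format="root: %(message)s")

# A handler on a named logger.
child = logging.getLogger("ppcls")
child.addHandler(logging.StreamHandler(sys.stdout))
child.setLevel(logging.INFO)

child.info("hello")        # printed twice: by the child handler AND,
                           # via propagation, by the root handler

child.propagate = False    # the fix this commit applies in the logger module
child.info("hello again")  # printed once
```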
@@ -71,7 +71,7 @@ class Engine(object):
         self.output_dir = self.config['Global']['output_dir']
         log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                 f"{mode}.log")
-        init_logger(name='root', log_file=log_file)
+        init_logger(log_file=log_file)
         print_config(config)
         # init train_func and eval_func
@@ -107,10 +107,11 @@ class Engine(object):
             self.scale_loss = 1.0
             self.use_dynamic_loss_scaling = False
         if self.amp:
-            AMP_RELATED_FLAGS_SETTING = {
-                'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
-                'FLAGS_max_inplace_grad_add': 8,
-            }
+            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+            if paddle.is_compiled_with_cuda():
+                AMP_RELATED_FLAGS_SETTING.update({
+                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+                })
             paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
         if "class_num" in config["Global"]:
......
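Aside from the logging change, this hunk guards the cuDNN-specific AMP flag so it is only set when Paddle was compiled with CUDA; on a CPU-only build that flag is not applicable. A standalone sketch of the same pattern (note: `paddle.fluid.set_flags` as used in the diff is the older spelling; current Paddle releases expose `paddle.set_flags`):

```python
import paddle

# Flags that are safe to set on any build.
amp_flags = {'FLAGS_max_inplace_grad_add': 8}

# Add the cuDNN batch-norm flag only when CUDA support is compiled in.
if paddle.is_compiled_with_cuda():
    amp_flags['FLAGS_cudnn_batchnorm_spatial_persistent'] = 1

paddle.set_flags(amp_flags)
```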
@@ -71,7 +71,7 @@ def main(args):
     log_file = os.path.join(global_config['output_dir'],
                             config["Arch"]["name"], f"{mode}.log")
-    init_logger(name='root', log_file=log_file)
+    init_logger(log_file=log_file)
     print_config(config)
     if global_config.get("is_distributed", True):
......
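Both call sites now omit `name='root'` and rely on the function's new default (`'ppcls'`, changed below). A hedged sketch of the resulting call, with hypothetical config values standing in for the parsed YAML:

```python
import os
from ppcls.utils.logger import init_logger  # assuming PaddleClas's module layout

output_dir = "./output"   # hypothetical Global.output_dir
arch_name = "ResNet50"    # hypothetical Arch.name
mode = "train"

log_file = os.path.join(output_dir, arch_name, f"{mode}.log")
init_logger(log_file=log_file)  # logger name defaults to 'ppcls' after this commit
```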
@@ -22,7 +22,7 @@ import paddle.distributed as dist
 _logger = None
-def init_logger(name='root', log_file=None, log_level=logging.INFO):
+def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
     """Initialize and get a logger by name.
     If the logger has not been initialized, this method will initialize the
     logger by adding one or two handlers, otherwise the initialized logger will
@@ -59,6 +59,7 @@ def init_logger(name='root', log_file=None, log_level=logging.INFO):
         _logger.setLevel(log_level)
     else:
         _logger.setLevel(logging.ERROR)
+    _logger.propagate = False
 def log_at_trainer0(log):
......
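Taken together, the two logger changes make the module log on a dedicated `'ppcls'` logger and keep its records from reaching the root logger. A condensed, hypothetical version of the resulting `init_logger` (the real function also formats records and silences non-zero trainer ranks, per the `logging.ERROR` branch in the diff):

```python
import logging

_logger = None

def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
    """Initialize the named logger once; later calls reuse it."""
    global _logger
    if _logger is not None:
        return _logger
    _logger = logging.getLogger(name)            # 'ppcls', no longer 'root'
    _logger.addHandler(logging.StreamHandler())  # console handler
    if log_file is not None:
        _logger.addHandler(logging.FileHandler(log_file, 'a'))
    _logger.setLevel(log_level)
    _logger.propagate = False  # the line added by this commit: stop records
                               # from also being handled by the root logger
    return _logger
```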