From 257438f32b3bf77eca3970fe3f50cbe4056b4719 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Tue, 23 Aug 2022 16:33:54 +0800
Subject: [PATCH] bugfix (#45332)

---
 python/paddle/distributed/auto_parallel/engine.py         | 5 ++---
 python/paddle/fluid/contrib/mixed_precision/fp16_utils.py | 5 +----
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index 35ff882491a..11953aa085d 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -189,9 +189,8 @@ class Engine:
         serial_main_prog = self._orig_main_prog.clone()
         serial_startup_prog = self._orig_startup_prog.clone()
         # FIXME to support grad clip
-        # with static.program_guard(serial_main_prog, serial_startup_prog), \
-        #         utils.unique_name.guard():
-        with static.program_guard(serial_main_prog, serial_startup_prog):
+        with static.program_guard(serial_main_prog, serial_startup_prog), \
+                utils.unique_name.guard():
             inputs_spec = self.inputs_spec
             labels_spec = self.labels_spec if self.labels_spec else []
             inputs = [s._create_feed_layer() for s in inputs_spec]
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
index e35dc901c83..b23c94c7e49 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -542,12 +542,9 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
     fp16_var_names = to_fp16_var_names if to_fp16_var_names else set()
     var_scope = scope if scope else global_scope()
-    print(
-        "======================cast_parameters_to_fp16=============================="
-    )
     for param in all_parameters:
         if param.name in fp16_var_names:
-            print("---- cast {} to fp16 dtype ----".format(param.name))
+            _logger.debug("---- cast {} to fp16 dtype ----".format(param.name))
             param_t = var_scope.find_var(param.name).get_tensor()
             data = np.array(param_t)
             param_t.set(np.float16(data), place)
--
GitLab