未验证 提交 257438f3 编写于 作者: J JZ-LIANG 提交者: GitHub

bugfix (#45332)

上级 fe6aacc0
......@@ -189,9 +189,8 @@ class Engine:
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
# FIXME to support grad clip
# with static.program_guard(serial_main_prog, serial_startup_prog), \
# utils.unique_name.guard():
with static.program_guard(serial_main_prog, serial_startup_prog):
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
......
......@@ -542,12 +542,9 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
fp16_var_names = to_fp16_var_names if to_fp16_var_names else set()
var_scope = scope if scope else global_scope()
print(
"======================cast_parameters_to_fp16=============================="
)
for param in all_parameters:
if param.name in fp16_var_names:
print("---- cast {} to fp16 dtype ----".format(param.name))
_logger.debug("---- cast {} to fp16 dtype ----".format(param.name))
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册