未验证 提交 54632b5c 编写于 作者: 0 0x45f 提交者: GitHub

Fix param@grad type error for amp in run_program (#40938)

上级 09e5b00c
......@@ -204,7 +204,9 @@ class PartialProgramLayer:
"""
Lazy initialized property of train_amp_program.
"""
return self._append_backward_desc(self._infer_amp_program)
train_amp_program = self._append_backward_desc(self._infer_amp_program)
self._set_grad_type(self._params, train_amp_program)
return train_amp_program
@LazyInitialized
@switch_to_static_graph
......@@ -224,7 +226,10 @@ class PartialProgramLayer:
"""
Lazy initialized property of _train_pure_fp16_program.
"""
return self._append_backward_desc(self._infer_pure_fp16_program)
train_pure_fp16_program = self._append_backward_desc(
self._infer_pure_fp16_program)
self._set_grad_type(self._params, train_pure_fp16_program)
return train_pure_fp16_program
@LazyInitialized
def _infer_program_id(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册