未验证 提交 54632b5c 编写于 作者: 0 0x45f 提交者: GitHub

Fix param@grad type error for amp in run_program (#40938)

上级 09e5b00c
...@@ -204,7 +204,9 @@ class PartialProgramLayer: ...@@ -204,7 +204,9 @@ class PartialProgramLayer:
""" """
Lazy initialized property of train_amp_program. Lazy initialized property of train_amp_program.
""" """
return self._append_backward_desc(self._infer_amp_program) train_amp_program = self._append_backward_desc(self._infer_amp_program)
self._set_grad_type(self._params, train_amp_program)
return train_amp_program
@LazyInitialized @LazyInitialized
@switch_to_static_graph @switch_to_static_graph
...@@ -224,7 +226,10 @@ class PartialProgramLayer: ...@@ -224,7 +226,10 @@ class PartialProgramLayer:
""" """
Lazy initialized property of _train_pure_fp16_program. Lazy initialized property of _train_pure_fp16_program.
""" """
return self._append_backward_desc(self._infer_pure_fp16_program) train_pure_fp16_program = self._append_backward_desc(
self._infer_pure_fp16_program)
self._set_grad_type(self._params, train_pure_fp16_program)
return train_pure_fp16_program
@LazyInitialized @LazyInitialized
def _infer_program_id(self): def _infer_program_id(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册