From 54632b5c49faac5c4091d5165e2c32172f5d7b75 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Fri, 25 Mar 2022 19:14:16 +0800 Subject: [PATCH] Fix param@grad type error for amp in run_program (#40938) --- .../fluid/dygraph/dygraph_to_static/partial_program.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index b8f8de67cc..90f960798e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -204,7 +204,9 @@ class PartialProgramLayer: """ Lazy initialized property of train_amp_program. """ - return self._append_backward_desc(self._infer_amp_program) + train_amp_program = self._append_backward_desc(self._infer_amp_program) + self._set_grad_type(self._params, train_amp_program) + return train_amp_program @LazyInitialized @switch_to_static_graph @@ -224,7 +226,10 @@ class PartialProgramLayer: """ Lazy initialized property of _train_pure_fp16_program. """ - return self._append_backward_desc(self._infer_pure_fp16_program) + train_pure_fp16_program = self._append_backward_desc( + self._infer_pure_fp16_program) + self._set_grad_type(self._params, train_pure_fp16_program) + return train_pure_fp16_program @LazyInitialized def _infer_program_id(self): -- GitLab