From f484a61e8722fff5cf8a19bdcc29f95d9173d326 Mon Sep 17 00:00:00 2001
From: WangZhen <23097963+0x45f@users.noreply.github.com>
Date: Tue, 3 Jan 2023 20:14:53 +0800
Subject: [PATCH] [Dy2St]Fix param and out grad names in dy2st for high order
 grad (#49461)

* Fix param and out grad names in dy2st for high order grad
---
 .../eager/to_static/run_program_op_func.h     |  6 --
 .../eager/to_static/run_program_op_node.h     | 47 +++++++++-------
 paddle/fluid/operators/run_program_op.cc      |  8 +++
 .../tests/unittests/test_eager_run_program.py |  4 ++
 .../tests/unittests/test_run_program_op.py    | 18 ++++++
 .../paddle/jit/dy2static/partial_program.py   | 55 ++++++++++++++++++-
 6 files changed, 108 insertions(+), 30 deletions(-)

diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h
index 8a6b59808d..e58c9bd0c4 100644
--- a/paddle/fluid/eager/to_static/run_program_op_func.h
+++ b/paddle/fluid/eager/to_static/run_program_op_func.h
@@ -80,16 +80,10 @@ inline void run_program_ad_func(
       trace_backward, &p_autograd_x, &p_autograd_params);
 
   if (require_any_grad) {
-    std::vector<std::string> out_names;
-    for (auto& t : deref_out) {
-      out_names.emplace_back(t.name());
-    }
-
     egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);
     // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad])
     auto grad_node = std::make_shared<GradNodeRunProgram>(1, 2);
 
-    grad_node->SetFwdOutNames(out_names);
     // Set Attributes
     grad_node->SetAttrMap(attrs);
     // Set TensorWrappers
diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h
index 46635b16ae..a593698963 100644
--- a/paddle/fluid/eager/to_static/run_program_op_node.h
+++ b/paddle/fluid/eager/to_static/run_program_op_node.h
@@ -791,13 +791,15 @@ class GradNodeRunProgram : public egr::GradNodeBase {
       }
     }
 
+    auto out_grad_names =
+        PADDLE_GET_CONST(std::vector<std::string>, attrs_.at("out_grad_names"));
     PADDLE_ENFORCE_EQ(hooked_grads[0].size(),
-                      fwd_out_names_.size(),
+                      out_grad_names.size(),
                       paddle::platform::errors::InvalidArgument(
                           "The hooked_grads[0].size() and "
-                          "fwd_out_names_.size() should be equal."));
-    for (size_t i = 0; i < fwd_out_names_.size(); ++i) {
-      hooked_grads[0][i].set_name(fwd_out_names_[i] + "@GRAD");
+                          "out_grad_names.size() should be equal."));
+    for (size_t i = 0; i < out_grad_names.size(); ++i) {
+      hooked_grads[0][i].set_name(out_grad_names[i]);
     }
     RunProgramGradAPI(x_,
                       params_,
@@ -829,10 +831,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
     step_scope_ = scopes;
   }
 
-  void SetFwdOutNames(std::vector<std::string> out_names) {
-    fwd_out_names_ = out_names;
-  }
-
  protected:
   void ConstructXGradTensors(
       const std::vector<paddle::experimental::Tensor> &x,
@@ -850,21 +848,30 @@ class GradNodeRunProgram : public egr::GradNodeBase {
   }
 
   void ConstructParamGradTensors(
-      const std::vector<paddle::experimental::Tensor> &param,
-      std::vector<paddle::experimental::Tensor> *param_grad) {
-    for (auto &t : param) {
-      auto t_grad = egr::EagerUtils::unsafe_autograd_meta(t)->Grad();
+      const std::vector<paddle::experimental::Tensor> &params,
+      std::vector<paddle::experimental::Tensor> *param_grads) {
+    auto param_grad_names = PADDLE_GET_CONST(std::vector<std::string>,
+                                             attrs_.at("param_grad_names"));
+    PADDLE_ENFORCE_EQ(params.size(),
+                      param_grad_names.size(),
+                      paddle::platform::errors::InvalidArgument(
+                          "The param.size() and "
+                          "param_grad_names.size() should be equal."));
+
+    for (size_t i = 0; i < params.size(); ++i) {
+      auto &p = params[i];
+      auto &p_grad = egr::EagerUtils::unsafe_autograd_meta(p)->Grad();
       // In eager mode, the number of param_grad should be the same as
       // param, so here an empty Tensor is added for the param with
      // stop_gradient=True
-      if (!t_grad.defined()) {
-        param_grad->emplace_back();
-      } else if (t_grad.is_dense_tensor()) {
-        param_grad->emplace_back(std::make_shared<phi::DenseTensor>());
-      } else if (t_grad.is_selected_rows()) {
-        param_grad->emplace_back(std::make_shared<phi::SelectedRows>());
+      if (!p_grad.defined()) {
+        param_grads->emplace_back();
+      } else if (p_grad.is_dense_tensor()) {
+        param_grads->emplace_back(std::make_shared<phi::DenseTensor>());
+      } else if (p_grad.is_selected_rows()) {
+        param_grads->emplace_back(std::make_shared<phi::SelectedRows>());
       }
-      param_grad->back().set_name(t.name() + "@GRAD");
+      param_grads->back().set_name(param_grad_names[i]);
     }
   }
 
@@ -880,8 +887,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
   std::vector<paddle::experimental::Tensor> params_;
   std::vector<paddle::framework::Scope *> step_scope_;
 
-  std::vector<std::string> fwd_out_names_;
-
   // Attribute Map
   paddle::framework::AttributeMap attrs_;
 };
diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc
index 52e35c3430..eb4f1b88c6 100644
--- a/paddle/fluid/operators/run_program_op.cc
+++ b/paddle/fluid/operators/run_program_op.cc
@@ -130,6 +130,14 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
         "(BlockDesc *)"
         "The global block of executed backward program desc.")
         .SetDefault(nullptr);
+    AddAttr<std::vector<std::string>>("param_grad_names",
+                                      "std::vector<std::string>"
+                                      "The names of parameter gradients.")
+        .SetDefault({});
+    AddAttr<std::vector<std::string>>("out_grad_names",
+                                      "std::vector<std::string>"
+                                      "The names of output gradients.")
+        .SetDefault({});
     AddComment(R"DOC(
 RunProgram operator.
 
diff --git a/python/paddle/fluid/tests/unittests/test_eager_run_program.py b/python/paddle/fluid/tests/unittests/test_eager_run_program.py
index c265f00a26..33472f85e7 100644
--- a/python/paddle/fluid/tests/unittests/test_eager_run_program.py
+++ b/python/paddle/fluid/tests/unittests/test_eager_run_program.py
@@ -135,6 +135,10 @@ class TestRunProgram(unittest.TestCase):
             False,
             'program_id',
             _hash_with_id(program),
+            'param_grad_names',
+            ['Fake_var@GRAD'],
+            'out_grad_names',
+            [out.name + '@GRAD'],
         ]
 
         use_interpretorcore = (
diff --git a/python/paddle/fluid/tests/unittests/test_run_program_op.py b/python/paddle/fluid/tests/unittests/test_run_program_op.py
index 193421ac07..7538fffb80 100644
--- a/python/paddle/fluid/tests/unittests/test_run_program_op.py
+++ b/python/paddle/fluid/tests/unittests/test_run_program_op.py
@@ -262,6 +262,15 @@ class RunProgramOpTest(unittest.TestCase):
                 )
             )
 
+        self.attrs.extend(
+            (
+                'param_grad_names',
+                [p.name + '@GRAD' for p in inputs['Params']],
+                'out_grad_names',
+                [out.name + '@GRAD' for out in outputs['Out']],
+            )
+        )
+
         _legacy_C_ops.run_program(
             inputs['X'],
             inputs['Params'],
@@ -305,6 +314,15 @@ class RunProgramOpTest(unittest.TestCase):
                 )
             )
 
+        self.attrs.extend(
+            (
+                'param_grad_names',
+                [p.name + '@GRAD' for p in inputs['Params']],
+                'out_grad_names',
+                [out.name + '@GRAD' for out in outputs['Out']],
+            )
+        )
+
         _legacy_C_ops.run_program(
             inputs['X'],
             inputs['Params'],
diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index ca678f6a6a..805c7b743b 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -253,7 +253,7 @@ class PartialProgramLayer:
 
     @switch_to_static_graph
     def _create_forward_backward_train_program(self):
-        whole_program = self._create_program()
+        whole_program = self._train_program
         forward_end_op_index = self._infer_program.desc.block(0).op_size()
         return self._get_forward_backward_program_form(
             whole_program, forward_end_op_index
@@ -261,7 +261,7 @@ class PartialProgramLayer:
 
     @switch_to_static_graph
     def _create_forward_backward_train_amp_program(self):
-        whole_program = self._create_amp_program()
+        whole_program = self._train_amp_program
         forward_end_op_index = self._infer_amp_program.desc.block(0).op_size()
         return self._get_forward_backward_program_form(
             whole_program, forward_end_op_index
@@ -269,7 +269,7 @@ class PartialProgramLayer:
 
     @switch_to_static_graph
     def _create_forward_backward_train_pure_fp16_program(self):
-        whole_program = self._create_pure_fp16_program()
+        whole_program = self._train_pure_fp16_program
         forward_end_op_index = self._infer_pure_fp16_program.desc.block(
             0
         ).op_size()
@@ -404,6 +404,43 @@ class PartialProgramLayer:
     def _infer_pure_fp16_program_id(self):
         return _hash_with_id(self._infer_pure_fp16_program, self)
 
+    @LazyInitialized
+    def _param_grad_names(self):
+        names = []
+        # NOTE: `names` and `self._params` must be in the same order so that
+        # the param grad name can be set correctly in the run_program.
+        for param in self._params:
+            candidate = [
+                var_name
+                for var_name in self.backward_program.block(0).vars.keys()
+                if var_name.endswith(param.name + '@GRAD')
+            ]
+            if candidate:
+                names.append(
+                    max(candidate, key=lambda name: name.count('grad/'))
+                )
+            else:
+                names.append(param.name + '@GRAD')
+        return names
+
+    @LazyInitialized
+    def _out_grad_names(self):
+        names = []
+        fwd_end_op_index = self._get_end_op_index()
+        for i in range(
+            fwd_end_op_index + 1,
+            min(
+                fwd_end_op_index + 2 * len(self._outputs.var_ids),
+                len(self.program.block(0).ops),
+            ),
+            2,
+        ):
+            op = self.program.block(0).ops[i]
+            if op.type == 'fill_constant':
+                var_name = op.output('Out')[0]
+                names.append(var_name)
+        return names
+
     @property
     def whole_program_id(self):
         if self.training:
@@ -610,6 +647,18 @@ class PartialProgramLayer:
             'program_id',
             self.program_id,
         ]
+        if self.training:
+            # NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
+            # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
+            # the correct names of the parameter grads from program. And out grads are similar to above.
+            attrs.extend(
+                (
+                    'param_grad_names',
+                    self._param_grad_names,
+                    'out_grad_names',
+                    self._out_grad_names,
+                )
+            )
         if self._cuda_graph_capture_mode:
             attrs.extend(
                 (
-- 
GitLab
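For readers following the naming logic in `PartialProgramLayer._param_grad_names` above: under higher-order differentiation, gradient variables in the backward-of-backward program carry nested `grad/` scope prefixes, and the patch selects the most deeply nested candidate for each parameter. The standalone sketch below is not part of the patch; it only mirrors that selection rule under the assumption of such prefixed names, and the helper name `pick_param_grad_name` plus the sample variable names are made up for illustration.

# Illustrative sketch only (not from the patch): mimics the selection rule used
# by PartialProgramLayer._param_grad_names. The entries in `backward_vars` are
# invented examples of how nested 'grad/' scopes can prefix grad variable names.
def pick_param_grad_name(param_name, backward_vars):
    candidate = [v for v in backward_vars if v.endswith(param_name + '@GRAD')]
    if candidate:
        # For higher-order grads, prefer the most deeply nested name,
        # e.g. 'grad/grad/linear_0.w_0@GRAD' over 'linear_0.w_0@GRAD'.
        return max(candidate, key=lambda name: name.count('grad/'))
    # A param with stop_gradient=True may have no grad var in the backward
    # program; fall back to the conventional '<param>@GRAD' suffix.
    return param_name + '@GRAD'


backward_vars = [
    'linear_0.w_0@GRAD',
    'grad/linear_0.w_0@GRAD',
    'grad/grad/linear_0.w_0@GRAD',
]
assert pick_param_grad_name('linear_0.w_0', backward_vars) == (
    'grad/grad/linear_0.w_0@GRAD'
)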