From a6868c9157a2950e1515ca653d804c7f1c66cf5d Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Tue, 12 Oct 2021 20:20:13 +0800
Subject: [PATCH] Fix stop_gradient in RunProgramOp (#36339) (#36353)

* Fix stop_gradient in RunProgramOp

* fix reference
---
 paddle/fluid/operators/run_program_op.h            | 26 +++++++---
 .../tests/unittests/test_run_program_op.py         | 48 +++++++++++++++++++
 2 files changed, 67 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/run_program_op.h b/paddle/fluid/operators/run_program_op.h
index ac352876e78..04e4dc62b03 100644
--- a/paddle/fluid/operators/run_program_op.h
+++ b/paddle/fluid/operators/run_program_op.h
@@ -142,10 +142,15 @@ static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
 
 static void ShareVarsFromScope(const std::vector<Variable *> &vars,
                                const std::vector<std::string> &var_names,
+                               const BlockDesc &global_block,
                                framework::Scope *scope) {
   for (size_t i = 0; i < vars.size(); ++i) {
+    // NOTE: In case of setting out_tmp.stop_gradient = True in model code, all
+    // parameters before generating out_tmp have no @GRAD, it will raise error
+    // because we can't find them in scope. So we skip sharing these vars or
+    // var@GRAD if they don't appear in global block.
     if (var_names[i] == framework::kEmptyVarName ||
-        var_names[i] == "Fake_var") {
+        var_names[i] == "Fake_var" || !global_block.HasVar(var_names[i])) {
       VLOG(2) << "find variable name is " << var_names[i] << ", skip it!";
       continue;
     }
@@ -214,8 +219,10 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
     details::ShareVarsIntoScope(input_vars, input_var_names, &scope);
     details::ShareVarsIntoScope(param_vars, param_names, &scope);
 
+    auto *global_block = ctx.Attr<BlockDesc *>("global_block");
+
     if (end_op_index > start_op_index) {
-      auto *program = ctx.Attr<BlockDesc *>("global_block")->Program();
+      auto *program = global_block->Program();
       auto cache_info = framework::GetExecutorInfoFromCache(
           *program, ctx.GetPlace(), start_op_index, end_op_index,
           /*is_grad=*/false, program_id, &scope);
@@ -240,8 +247,10 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
       parallel_executor->RunWithoutFetch(skip_eager_delete_vars);
     }
     // Step 4. Get Output
-    details::ShareVarsFromScope(output_vars, output_var_names, &scope);
-    details::ShareVarsFromScope(dout_vars, dout_var_names, &scope);
+    details::ShareVarsFromScope(output_vars, output_var_names, *global_block,
+                                &scope);
+    details::ShareVarsFromScope(dout_vars, dout_var_names, *global_block,
+                                &scope);
 
     // Debug info: scope info when run end
     VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());
@@ -307,10 +316,11 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
                           "least one sub scope."));
 
     auto &scope = *(global_inner_scope->kids().front());
+    auto *global_block = ctx.Attr<BlockDesc *>("global_block");
 
     if (end_op_index > start_op_index) {
       // Step 2. prepare executor and scope
-      auto *program = ctx.Attr<BlockDesc *>("global_block")->Program();
+      auto *program = global_block->Program();
       auto cache_info = framework::GetExecutorInfoFromCache(
           *program, ctx.GetPlace(), start_op_index, end_op_index,
           /*is_grad*/ true, program_id, &scope);
@@ -341,8 +351,10 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
     }
 
     // Step 4. get outputs
-    details::ShareVarsFromScope(input_grad_vars, input_grad_var_names, &scope);
-    details::ShareVarsFromScope(param_grad_vars, param_grad_names, &scope);
+    details::ShareVarsFromScope(input_grad_vars, input_grad_var_names,
+                                *global_block, &scope);
+    details::ShareVarsFromScope(param_grad_vars, param_grad_names,
+                                *global_block, &scope);
 
     // Step5. drop current scope
     global_inner_scope->DeleteScope(&scope);
diff --git a/python/paddle/fluid/tests/unittests/test_run_program_op.py b/python/paddle/fluid/tests/unittests/test_run_program_op.py
index b3d0845a4fb..33b32a6632c 100644
--- a/python/paddle/fluid/tests/unittests/test_run_program_op.py
+++ b/python/paddle/fluid/tests/unittests/test_run_program_op.py
@@ -343,5 +343,53 @@ class TestRunProgramOpWithEmbedding(RunProgramOpTest):
         return fwd_op_num
 
 
+class Net(paddle.nn.Layer):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.fc1 = paddle.nn.Linear(10, 10)
+        self.fc2 = paddle.nn.Linear(10, 1)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out.stop_gradient = True
+        out = self.fc2(out)
+        return out
+
+
+class TestParametersWithStopGradient(unittest.TestCase):
+    def setUp(self):
+        self.seed = 2021
+        self.iter = 5
+
+    def train(self, to_static):
+        # prepare env
+        paddle.seed(self.seed)
+
+        net = Net()
+        if to_static:
+            net = paddle.jit.to_static(net)
+        sgd = paddle.optimizer.SGD(0.01, parameters=net.parameters())
+
+        for i in range(self.iter):
+            x = paddle.rand([4, 10])
+            out = net(x)
+            loss = paddle.mean(out)
+
+            loss.backward()
+            sgd.minimize(loss)
+            net.clear_gradients()
+
+        return loss
+
+    def test_stop_gradient(self):
+        paddle.disable_static()
+
+        dy_loss = self.train(to_static=False)
+        st_loss = self.train(to_static=True)
+        self.assertEqual(dy_loss[0], st_loss[0])
+
+        paddle.enable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
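
For context, the failure this patch addresses can be triggered with a few lines of dygraph-to-static code. The sketch below is not part of the patch and uses illustrative names (StopGradNet); it assumes a Paddle 2.x install with paddle.jit.to_static available, and simply shows the situation described in the NOTE comment above: an intermediate tensor is marked stop_gradient inside forward(), so the parameters created before it never get @GRAD variables in the traced global block, which is exactly the case the new global_block.HasVar() check in ShareVarsFromScope skips.

    # Minimal reproduction sketch (illustrative only, assumes paddle >= 2.x).
    import paddle


    class StopGradNet(paddle.nn.Layer):
        def __init__(self):
            super(StopGradNet, self).__init__()
            self.fc1 = paddle.nn.Linear(10, 10)
            self.fc2 = paddle.nn.Linear(10, 1)

        def forward(self, x):
            out = self.fc1(x)
            out.stop_gradient = True  # cuts the graph: fc1's params get no @GRAD
            return self.fc2(out)


    paddle.disable_static()
    net = paddle.jit.to_static(StopGradNet())
    loss = paddle.mean(net(paddle.rand([4, 10])))
    # Before this fix, backward() could raise an error because RunProgramGradOp
    # tried to share fc1's non-existent @GRAD vars out of the inner scope; with
    # the HasVar() check those names are now skipped, matching dygraph behavior.
    loss.backward()

The added TestParametersWithStopGradient test in the patch checks the same scenario end to end by comparing the dygraph and to_static losses.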