未验证 提交 2a75b447 编写于 作者: A Aurelius84 提交者: GitHub

Fix stop_gradient in RunProgramOp (#36339)

* Fix stop_gradient in RunProgramOp

* fix reference
上级 6d353aa5
......@@ -142,10 +142,15 @@ static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
static void ShareVarsFromScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
const BlockDesc &global_block,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
// NOTE: In case of setting out_tmp.stop_gradient = True in model code, all
// parameters before generating out_tmp have no @GRAD, it will raise error
// because we can't findthem in scope. So we skip sharing these vars or
// var@GRAD if they don't appear in global block.
if (var_names[i] == framework::kEmptyVarName ||
var_names[i] == "Fake_var") {
var_names[i] == "Fake_var" || !global_block.HasVar(var_names[i])) {
VLOG(2) << "find variable name is " << var_names[i] << ", skip it!";
continue;
}
......@@ -214,8 +219,10 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
details::ShareVarsIntoScope(input_vars, input_var_names, &scope);
details::ShareVarsIntoScope(param_vars, param_names, &scope);
auto *global_block = ctx.Attr<BlockDesc *>("global_block");
if (end_op_index > start_op_index) {
auto *program = ctx.Attr<BlockDesc *>("global_block")->Program();
auto *program = global_block->Program();
auto cache_info = framework::GetExecutorInfoFromCache(
*program, ctx.GetPlace(), start_op_index, end_op_index,
/*is_grad=*/false, program_id, &scope);
......@@ -240,8 +247,10 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
parallel_executor->RunWithoutFetch(skip_eager_delete_vars);
}
// Step 4. Get Output
details::ShareVarsFromScope(output_vars, output_var_names, &scope);
details::ShareVarsFromScope(dout_vars, dout_var_names, &scope);
details::ShareVarsFromScope(output_vars, output_var_names, *global_block,
&scope);
details::ShareVarsFromScope(dout_vars, dout_var_names, *global_block,
&scope);
// Debug info: scope info when run end
VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());
......@@ -307,10 +316,11 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
"least one sub scope."));
auto &scope = *(global_inner_scope->kids().front());
auto *global_block = ctx.Attr<BlockDesc *>("global_block");
if (end_op_index > start_op_index) {
// Step 2. prepare executor and scope
auto *program = ctx.Attr<BlockDesc *>("global_block")->Program();
auto *program = global_block->Program();
auto cache_info = framework::GetExecutorInfoFromCache(
*program, ctx.GetPlace(), start_op_index, end_op_index,
/*is_grad*/ true, program_id, &scope);
......@@ -341,8 +351,10 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
}
// Step 4. get outputs
details::ShareVarsFromScope(input_grad_vars, input_grad_var_names, &scope);
details::ShareVarsFromScope(param_grad_vars, param_grad_names, &scope);
details::ShareVarsFromScope(input_grad_vars, input_grad_var_names,
*global_block, &scope);
details::ShareVarsFromScope(param_grad_vars, param_grad_names,
*global_block, &scope);
// Step5. drop current scope
global_inner_scope->DeleteScope(&scope);
......
......@@ -343,5 +343,53 @@ class TestRunProgramOpWithEmbedding(RunProgramOpTest):
return fwd_op_num
class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()
self.fc1 = paddle.nn.Linear(10, 10)
self.fc2 = paddle.nn.Linear(10, 1)
def forward(self, x):
out = self.fc1(x)
out.stop_gradient = True
out = self.fc2(out)
return out
class TestParametersWithStopGradient(unittest.TestCase):
def setUp(self):
self.seed = 2021
self.iter = 5
def train(self, to_static):
# prepare env
paddle.seed(self.seed)
net = Net()
if to_static:
net = paddle.jit.to_static(net)
sgd = paddle.optimizer.SGD(0.01, parameters=net.parameters())
for i in range(self.iter):
x = paddle.rand([4, 10])
out = net(x)
loss = paddle.mean(out)
loss.backward()
sgd.minimize(loss)
net.clear_gradients()
return loss
def test_stop_gradient(self):
paddle.disable_static()
dy_loss = self.train(to_static=False)
st_loss = self.train(to_static=True)
self.assertEqual(dy_loss[0], st_loss[0])
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册