diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc
index fd06e33a6bb6e03b8e90c47ae3edbb9ce18e0e85..7ffbf1933be374b49ba3a5e8bb868849abc245f3 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -153,7 +153,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
                /* keep_kid_scopes */ false);
 
       AssignLocalGradientToParentScope(dev_place, cur_scope, scope,
-                                       inside_grads, outside_grads);
+                                       inside_grads, outside_grads, inputs);
       return;
     }
 
@@ -165,27 +165,36 @@ class ConditionalBlockGradOp : public ConditionalOp {
       const platform::Place &place, const framework::Scope &cur_scope,
       const framework::Scope &parent_scope,
       const std::vector<std::string> &inside_grads,
-      const std::vector<std::string> &outside_grads) const {
+      const std::vector<std::string> &outside_grads,
+      const std::vector<std::string> &inputs) const {
+    std::vector<std::string> assign_zero_outside_grads;
+    std::vector<std::string> assign_zero_inputs;
     for (size_t i = 0; i < outside_grads.size(); ++i) {
       const std::string &outside_grad_name = outside_grads[i];
       const std::string &inside_grad_name = inside_grads[i];
       VLOG(4) << "inside_grad_name = " << inside_grad_name
               << ", outside_grad_name = " << outside_grad_name;
-      framework::Variable *inside_var =
-          cur_scope.FindLocalVar(inside_grad_name);
-      if (inside_var == nullptr) {
-        continue;
-      }
       framework::Variable *outside_var =
           parent_scope.FindVar(outside_grad_name);
       if (outside_var == nullptr) {
         continue;
       }
+      framework::Variable *inside_var =
+          cur_scope.FindLocalVar(inside_grad_name);
+      if (inside_var == nullptr) {
+        assign_zero_outside_grads.emplace_back(outside_grad_name);
+        assign_zero_inputs.emplace_back(inputs[i]);
+        continue;
+      }
       platform::DeviceContext *dev_ctx =
           platform::DeviceContextPool::Instance().Get(place);
       framework::VisitVarType(*inside_var,
                               AssignFunctor(outside_var, *dev_ctx));
     }
+    // Assign zero to the grad_vars that are in outside_grads but not in
+    // inside_grads
+    AssignZeroToParentScope(place, parent_scope, assign_zero_inputs,
+                            assign_zero_outside_grads);
   }
 
   void AssignZeroToParentScope(
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
index 9a9e7ee243872b9068bcd0107d5300d8177200bc..276aa68e895c68345352413dd47decc1456dac20 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
@@ -424,6 +424,41 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
         ProgramTranslator().enable(False)
 
 
+class IfElseNet(paddle.nn.Layer):
+    def __init__(self):
+        super(IfElseNet, self).__init__()
+        self.param = self.create_parameter(
+            shape=[3, 2], dtype='float32', is_bias=False)
+
+    @paddle.jit.to_static
+    def forward(self, a, b, c):
+        a = paddle.matmul(a, self.param)
+        a = paddle.reshape(a, (2, 4))
+        cond = paddle.to_tensor([10])
+        if cond == 10:
+            a_argmax = a.argmax(axis=-1)
+            b = b + self.param
+        else:
+            print(c)
+        return b
+
+
+class TestDy2StIfElseBackward(unittest.TestCase):
+    def test_run_backward(self):
+        a = paddle.randn((4, 3), dtype='float32')
+        a.stop_gradient = False
+        b = paddle.to_tensor([10]).astype('float32')
+        b.stop_gradient = False
+        c = paddle.to_tensor([2])
+        c.stop_gradient = False
+
+        net = IfElseNet()
+        net.train()
+        out = net(a, b, c)
+        out.backward()
+        self.assertTrue(np.allclose((b + net.param).numpy(), out.numpy()))
+
+
 if __name__ == '__main__':
     with paddle.fluid.framework._test_eager_guard():
         unittest.main()
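
Note (not part of the patch): below is a minimal dygraph-to-static sketch of the case this change covers, written against the public Paddle 2.x API; the function and variable names are illustrative. Here `x` reaches the taken branch only through `argmax`, which has no gradient op, so the grad sub-block never creates an inside grad var for it. With this patch, `AssignLocalGradientToParentScope` collects such names and hands them to `AssignZeroToParentScope`, so the outside gradient should come back zero-filled instead of being skipped, mirroring what `a_argmax` exercises in the new unit test.

import paddle

@paddle.jit.to_static
def branchy(x, y):
    # Tensor-dependent condition: dy2static lowers this if/else
    # into conditional_block ops.
    cond = paddle.to_tensor([10])
    if cond == 10:
        # x enters the sub-block, but argmax has no grad op,
        # so no inside grad for x is ever produced.
        idx = x.argmax(axis=-1)
        out = y * 2
    else:
        out = y - 1
    return out

x = paddle.randn([4, 3])
x.stop_gradient = False
y = paddle.randn([4, 3])
y.stop_gradient = False

out = branchy(x, y)
out.sum().backward()
# With the fix, x's outside gradient is expected to be all zeros
# rather than absent from the parent scope.
print(x.grad)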