From 58a88ba9af79a720074b23892fb137fc7642449c Mon Sep 17 00:00:00 2001 From: lilong12 Date: Thu, 10 Sep 2020 15:46:39 +0800 Subject: [PATCH] add double grad for expand (#27183) * add double grad for expand, test=develop --- paddle/fluid/operators/expand_op.cc | 22 +++++++++++++++++ .../fluid/tests/unittests/test_nn_grad.py | 24 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 3c898ac29f0..83e205367a7 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -228,6 +228,26 @@ class ExpandGradOpMaker : public framework::SingleGradOpMaker { } }; +template +class ExpandDoubleGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetInput("X", this->OutputGrad(framework::GradVarName("X"))); + op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out"))); + if (this->HasInput("expand_times_tensor")) { + op->SetInput("expand_times_tensor", this->Input("expand_times_tensor")); + } + if (this->HasInput("ExpandTimes")) { + op->SetInput("ExpandTimes", this->Input("ExpandTimes")); + } + op->SetAttrMap(this->Attrs()); + op->SetType("expand"); + } +}; + DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandGradNoNeedBufVarsInferer, "X"); } // namespace operators @@ -238,6 +258,8 @@ REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker, ops::ExpandGradOpMaker, ops::ExpandGradOpMaker); REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp, + ops::ExpandDoubleGradOpMaker, + ops::ExpandDoubleGradOpMaker, ops::ExpandGradNoNeedBufVarsInferer); REGISTER_OP_CPU_KERNEL( expand, ops::ExpandKernel, diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index 0c39dc5e731..0e4f89f6026 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -153,6 +153,30 @@ class TestMulDoubleGradCheck(unittest.TestCase): class TestReshapeDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + x_shape = [3, 12] + expand_times = [4, 9] + eps = 0.005 + dtype = np.float64 + + x = layers.data('x', x_shape, False, dtype) + x.persistable = True + out = layers.expand(x, expand_times) + x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + + gradient_checker.double_grad_check( + [x], out, x_init=x_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestExpandDoubleGradCheck(unittest.TestCase): @prog_scope() def func(self, place): x_shape = [3, 12] -- GitLab