From cc387159f3be6e8d5dd37b036a46899d4dbde21e Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 16 Dec 2020 21:06:21 +0800 Subject: [PATCH] add pad and concat double grad (#29549) * add constant pad double grad --- paddle/fluid/operators/concat_op.cc | 16 +++++ paddle/fluid/operators/pad3d_op.cc | 18 +++++ paddle/fluid/operators/pad_op.cc | 17 ++++- .../fluid/tests/unittests/test_nn_grad.py | 65 +++++++++++++++++++ 4 files changed, 115 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 0b3697156d..e84f022214 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -201,6 +201,20 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker { } }; +template +class ConcatDoubleGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("concat"); + grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X"))); + grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out"))); + grad_op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle @@ -209,6 +223,8 @@ REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker, ops::ConcatGradOpMaker, ops::ConcatGradOpMaker); REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad, + ops::ConcatDoubleGradOpMaker, + ops::ConcatDoubleGradOpMaker, ops::ConcatOpGradNoNeedBufferVarInferer); REGISTER_OP_CPU_KERNEL( concat, ops::ConcatKernel, diff --git a/paddle/fluid/operators/pad3d_op.cc b/paddle/fluid/operators/pad3d_op.cc index 1d41b823b6..0751cf2558 100644 --- a/paddle/fluid/operators/pad3d_op.cc +++ b/paddle/fluid/operators/pad3d_op.cc @@ -893,6 +893,22 @@ class Pad3dOpGradMaker : public framework::SingleGradOpMaker { } }; +template +class Pad3dOpDoubleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr grad_op) const override { + if (this->HasInput("Paddings")) { + grad_op->SetInput("Paddings", this->Input("Paddings")); + } + grad_op->SetType("pad3d"); + grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X"))); + grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out"))); + grad_op->SetAttrMap(this->Attrs()); + } +}; + DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad3dOpGradNoNeedBufferVarsInferer, "X"); } // namespace operators @@ -904,6 +920,8 @@ REGISTER_OPERATOR(pad3d, ops::Pad3dOp, ops::Pad3dOpMaker, ops::Pad3dOpGradMaker, ops::Pad3dOpGradMaker); REGISTER_OPERATOR(pad3d_grad, ops::Pad3dOpGrad, + ops::Pad3dOpDoubleGradMaker, + ops::Pad3dOpDoubleGradMaker, ops::Pad3dOpGradNoNeedBufferVarsInferer); REGISTER_OP_CPU_KERNEL(pad3d, ops::Pad3dCPUKernel, ops::Pad3dCPUKernel, ops::Pad3dCPUKernel, diff --git a/paddle/fluid/operators/pad_op.cc b/paddle/fluid/operators/pad_op.cc index 91de48100a..577f4f3941 100644 --- a/paddle/fluid/operators/pad_op.cc +++ b/paddle/fluid/operators/pad_op.cc @@ -142,6 +142,19 @@ class PadOpGradMaker : public framework::SingleGradOpMaker { } }; +template +class PadOpDoubleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("pad"); + grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X"))); + grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out"))); + grad_op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle @@ -150,7 +163,9 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker, ops::PadOpGradMaker, ops::PadOpGradMaker); -REGISTER_OPERATOR(pad_grad, ops::PadOpGrad); +REGISTER_OPERATOR(pad_grad, ops::PadOpGrad, + ops::PadOpDoubleGradMaker, + ops::PadOpDoubleGradMaker); REGISTER_OP_CPU_KERNEL( pad, ops::PadKernel, ops::PadKernel, diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index 6a5e1ba147..d7bbc355d5 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -394,5 +394,70 @@ class TestTransposeDoubleGradCheckCase1(unittest.TestCase): self.func(p) +class TestConstantPadDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + x_shape = [2, 3, 4, 5] + pad = [1, 1, 1, 1] + eps = 0.005 + dtype = np.float64 + + x = layers.data('x', x_shape, False, dtype) + x.persistable = True + out = paddle.nn.functional.pad(x, pad) + x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + + gradient_checker.double_grad_check( + [x], out, x_init=x_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConstantPadDoubleGradCheckCase1(TestConstantPadDoubleGradCheck): + @prog_scope() + def func(self, place): + x_shape = [2, 3, 4, 5] + pad = [1, 0, 1, 0, 1, 0, 1, 0] + dtype = np.float64 + + x = layers.data('x', x_shape, False, dtype) + x.persistable = True + out = paddle.nn.functional.pad(x, pad) + x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + + gradient_checker.double_grad_check([x], out, x_init=x_arr, place=place) + + +class TestConcatDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + x_shape = [2, 3, 4, 5] + pad = [1, 1, 1, 1] + dtype = np.float64 + + x1 = layers.data('x', x_shape, False, dtype) + x2 = layers.data('x', x_shape, False, dtype) + x1.persistable = True + x2.persistable = True + out = paddle.concat([x1, x2], axis=0) + x2_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + x1_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + + gradient_checker.double_grad_check( + [x1, x2], out, x_init=[x1_arr, x2_arr], place=place) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + if __name__ == "__main__": unittest.main() -- GitLab