diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index ccf43330840ba7648c92a9f7a04c36503f6e61a6..5daa937b017dd19e0050b77181dcf2bcf4b8cc5c 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -363,6 +363,20 @@ class ReshapeGradKernel { } }; +class ReshapeDoubleGradKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *dd_x = ctx.Input("DDX"); + auto *dd_out = ctx.Output("DDOut"); + + auto out_dims = dd_out->dims(); + + dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); + framework::TensorCopySync(*dd_x, ctx.GetPlace(), dd_out); + dd_out->Resize(out_dims); + } +}; + // FIXME(zcd): reshape2 adds an intermediate output(XShape) based on reshape, // the XShape is used to carry the shape and lod of X which will be used in // reshape_grad, in this way, the framework can reuse the memory of X @@ -409,6 +423,7 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker { std::unique_ptr Apply() const override { auto *grad_op = new framework::OpDesc(); grad_op->SetType("reshape2_grad"); + grad_op->SetInput("X", Input("X")); grad_op->SetInput("XShape", Output("XShape")); grad_op->SetInput("ShapeTensor", Input("ShapeTensor")); grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); @@ -418,6 +433,27 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker { } }; +class Reshape2DoubleGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDesc(); + grad_op->SetType("reshape2_grad_grad"); + + grad_op->SetInput("X", Input("X")); + grad_op->SetInput("ShapeTensor", Input("ShapeTensor")); + grad_op->SetInput("DOut", Input(framework::GradVarName("Out"))); + grad_op->SetInput("DDX", OutputGrad(framework::GradVarName("X"))); + + auto ddx = OutputGrad(framework::GradVarName("X")); + + grad_op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out"))); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + class Reshape2GradOp : public framework::OperatorWithKernel { public: Reshape2GradOp(const std::string &type, @@ -456,10 +492,47 @@ class Reshape2GradOp : public framework::OperatorWithKernel { } }; +class Reshape2DoubleGradOp : public framework::OperatorWithKernel { + public: + Reshape2DoubleGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) shouldn't be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true, + "Input(X@GRAD_GRAD) shouldn't be null."); + + if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) { + ctx->ShareDim("DOut", "DDOut"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.Input("DDX")->type(), + ctx.device_context()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "ShapeTensor") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut, {framework::GradVarName("Out"), framework::GradVarName("X")}); +DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"}); } // namespace operators } // namespace paddle @@ -471,6 +544,7 @@ REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, ops::ReshapeOpInplaceInToOut); REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp, ops::ReshapeGradInplaceInToOut); + REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel); @@ -478,11 +552,13 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel); - REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker, ops::Reshape2GradMaker, ops::ReshapeOpInplaceInToOut); REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, - ops::ReshapeGradInplaceInToOut); + ops::Reshape2DoubleGradMaker, ops::ReshapeGradInplaceInToOut); +REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp, + ops::ReshapeDoubleGradInplaceInToOut); + REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel); @@ -490,6 +566,11 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float, + ops::ReshapeDoubleGradKernel, double, + ops::ReshapeDoubleGradKernel, int, + ops::ReshapeDoubleGradKernel, int64_t, + ops::ReshapeDoubleGradKernel); #ifdef PADDLE_WITH_CUDA REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, @@ -510,4 +591,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, plat::float16, ops::ReshapeGradKernel); + +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float, + ops::ReshapeDoubleGradKernel, double, + ops::ReshapeDoubleGradKernel, int, + ops::ReshapeDoubleGradKernel, int64_t, + ops::ReshapeDoubleGradKernel, plat::float16, + ops::ReshapeDoubleGradKernel); #endif diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index 8bbd9443230a695a0daeafe171195e0169f65ca7..1434fdf0d0a89e633de409538008845baefae444 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -94,5 +94,29 @@ class TestMulDoubleGradCheck(unittest.TestCase): self.func(p) +class TestReshapeDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + x_shape = [3, 12] + new_shape = [4, 9] + eps = 0.005 + dtype = np.float64 + + x = layers.data('x', x_shape, False, dtype) + x.persistable = True + out = layers.reshape(x, new_shape) + x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + + gradient_checker.double_grad_check( + [x], out, x_init=x_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + if __name__ == "__main__": unittest.main()