From e0bca5f86640b79386b2020b043efa8fd223c47a Mon Sep 17 00:00:00 2001
From: Dun
Date: Mon, 17 Sep 2018 10:25:17 +0800
Subject: [PATCH] Implement slice grad operator. #8130 (#12330)

* Implement slice grad operator. #8130

* test slice grad operator and bug fix

* Fix pre commit style
---
 paddle/fluid/operators/slice_op.cc             | 55 ++++++++++++--
 paddle/fluid/operators/slice_op.cu             |  7 ++
 paddle/fluid/operators/slice_op.h              | 75 +++++++++++++++++++
 .../fluid/tests/unittests/test_slice_op.py     |  3 +
 4 files changed, 132 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc
index 4bd23d594..e55462d6c 100644
--- a/paddle/fluid/operators/slice_op.cc
+++ b/paddle/fluid/operators/slice_op.cc
@@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"),
                    "Input (Input) of slice op should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
+      const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
         ctx.GetPlace());
@@ -87,13 +87,13 @@ Slice Operator.
 
 Produces a slice of the input tensor along multiple axes. Similar to numpy:
 https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
-Slice uses `axes`, `starts` and `ends` attributes to specify the start and 
+Slice uses `axes`, `starts` and `ends` attributes to specify the start and
 end dimension for each axis in the list of axes, it uses this information
-to slice the input data tensor. If a negative value is passed for any of 
-the start or end indices, it represents number of elements before the end 
+to slice the input data tensor. If a negative value is passed for any of
+the start or end indices, it represents number of elements before the end
 of that dimension. If the value passed to start or end is larger than
-the n (the number of elements in this dimension), it represents n. 
-For slicing to the end of a dimension with unknown size, it is recommended 
+the n (the number of elements in this dimension), it represents n.
+For slicing to the end of a dimension with unknown size, it is recommended
 to pass in INT_MAX. If axes are omitted, they are set to [0, ..., ndim-1].
 
 Following examples will explain how slice works:
@@ -119,15 +119,54 @@ Following examples will explain how slice works:
   }
 };
 
+class SliceOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx->GetInputDim("Input");
+    auto x_grad_name = framework::GradVarName("Input");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+};
+
+class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* bind = new framework::OpDesc();
+    bind->SetInput("Input", Input("Input"));
+    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
+    bind->SetAttrMap(Attrs());
+    bind->SetType("slice_grad");
+    return std::unique_ptr<framework::OpDesc>(bind);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+                  ops::SliceOpGradMaker);
+REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
     slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/slice_op.cu b/paddle/fluid/operators/slice_op.cu
index 8c1767c70..5efecb78d 100644
--- a/paddle/fluid/operators/slice_op.cu
+++ b/paddle/fluid/operators/slice_op.cu
@@ -20,3 +20,10 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
     ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
     ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
+
+REGISTER_OP_CUDA_KERNEL(
+    slice_grad,
+    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/slice_op.h b/paddle/fluid/operators/slice_op.h
index ba231aee1..f38d08d76 100644
--- a/paddle/fluid/operators/slice_op.h
+++ b/paddle/fluid/operators/slice_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include <vector>
+#include <algorithm>
 #include "paddle/fluid/framework/op_registry.h"
@@ -84,5 +85,79 @@ class SliceKernel : public framework::OpKernel<T> {
     out_t.device(place) = in_t.slice(offsets, extents);
   }
 };
+
+template <typename DeviceContext, typename T>
+class SliceGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    size_t rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
+                      ->dims()
+                      .size();
+    switch (rank) {
+      case 1:
+        SliceCompute<1>(ctx);
+        break;
+      case 2:
+        SliceCompute<2>(ctx);
+        break;
+      case 3:
+        SliceCompute<3>(ctx);
+        break;
+      case 4:
+        SliceCompute<4>(ctx);
+        break;
+      case 5:
+        SliceCompute<5>(ctx);
+        break;
+      case 6:
+        SliceCompute<6>(ctx);
+        break;
+    }
+  }
+
+ private:
+  template <size_t D>
+  void SliceCompute(const framework::ExecutionContext& context) const {
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    auto* d_out =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* d_input =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+    d_input->mutable_data<T>(context.GetPlace());
+    auto out_dims = d_out->dims();
+    auto in_dims = d_input->dims();
+    auto axes = context.Attr<std::vector<int>>("axes");
+    auto starts = context.Attr<std::vector<int>>("starts");
+
+    auto offsets = Eigen::array<int64_t, D>();
+    auto extents = Eigen::array<int64_t, D>();
+    for (size_t i = 0; i < D; ++i) {
+      offsets[i] = 0;
+      extents[i] = out_dims[i];
+    }
+    int start;
+    for (size_t i = 0; i < axes.size(); ++i) {
+      start = starts[i];
+      if (start < 0) {
+        start = (start + in_dims[axes[i]]);
+      }
+      start = std::max(start, 0);
+      offsets[axes[i]] = start;
+    }
+    Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
+    for (size_t i = 0; i < paddings.size(); ++i) {
+      paddings[i].first = offsets[i];
+      paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
+    }
+    auto d_in_t =
+        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
+            *d_input);
+    auto d_out_t =
+        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
+            *d_out);
+    d_in_t.device(place) = d_out_t.pad(paddings, 0);
+  }
+};
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py
index 134df38ee..4e6ed3a74 100644
--- a/python/paddle/fluid/tests/unittests/test_slice_op.py
+++ b/python/paddle/fluid/tests/unittests/test_slice_op.py
@@ -41,6 +41,9 @@ class TestSliceOp(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad_normal(self):
+        self.check_grad(['Input'], 'Out', max_relative_error=0.006)
+
 
 class TestCase1(TestSliceOp):
     def config(self):
-- 
GitLab
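
A note to accompany the patch: the slice docstring above specifies the attribute semantics in prose (negative indices count from the end of a dimension, an oversized end is clamped, INT_MAX slices to the end). A minimal NumPy sketch of those rules, for illustration only; it is not part of the patch, and slice_attrs is a hypothetical helper:

import numpy as np

INT_MAX = 2**31 - 1

def slice_attrs(x, axes, starts, ends):
    # For each listed axis take [start, end); negative values count
    # from the end of the dimension, and an end beyond the dimension
    # size is clamped to it (so INT_MAX means "to the end").
    idx = [slice(None)] * x.ndim
    for axis, start, end in zip(axes, starts, ends):
        idx[axis] = slice(start, min(end, x.shape[axis]))
    return x[tuple(idx)]

data = np.arange(20).reshape(4, 5)
print(slice_attrs(data, axes=[0, 1], starts=[1, 0], ends=[3, 3]).shape)  # (2, 3)
print(slice_attrs(data, axes=[0], starts=[-2], ends=[INT_MAX]).shape)    # (2, 5)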
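
SliceGradKernel above realizes the backward pass as a single Eigen pad(): the incoming gradient d_out is surrounded with zeros so that every element lands exactly where the forward slice read it, using paddings[i].second = in_dims[i] - out_dims[i] - offsets[i]. The same arithmetic in NumPy, as a sketch only (slice_grad is an assumed name, not an API in the patch):

import numpy as np

def slice_grad(d_out, in_shape, axes, starts):
    # Recompute the forward offsets exactly as SliceCompute does:
    # normalize negative starts, then clamp at zero.
    offsets = [0] * len(in_shape)
    for i, axis in enumerate(axes):
        start = starts[i]
        if start < 0:
            start += in_shape[axis]
        offsets[axis] = max(start, 0)
    # Pad with offsets[i] zeros in front of each dimension and
    # in_shape[i] - d_out.shape[i] - offsets[i] zeros behind.
    paddings = [(off, dim - out - off)
                for off, dim, out in zip(offsets, in_shape, d_out.shape)]
    return np.pad(d_out, paddings, mode='constant')

d_out = np.ones((2, 4), dtype=np.float32)  # gradient of x[1:3, -4:] for x of shape (4, 6)
d_x = slice_grad(d_out, (4, 6), axes=[0, 1], starts=[1, -4])
assert d_x.shape == (4, 6)
assert d_x.sum() == d_out.size  # every gradient element is preserved

Because the pad is the exact adjoint of the slice, the backward kernel never needs the ends attribute: the shape of d_out already encodes the extents, which is why SliceCompute reads only axes and starts.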