Commit e0bca5f8 authored by Dun, committed by qingqing01

Implement slice grad operator. #8130 (#12330)

* Implement slice grad operator. #8130
* test slice grad operator and bug fix
* Fix pre commit style
Parent 03dc7b79
@@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input (Input) of slice op should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
        ctx.GetPlace());
@@ -87,13 +87,13 @@ Slice Operator.
Produces a slice of the input tensor along multiple axes. Similar to numpy:
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
Slice uses `axes`, `starts` and `ends` attributes to specify the start and
end dimension for each axis in the list of axes, it uses this information
to slice the input data tensor. If a negative value is passed for any of
the start or end indices, it represents number of elements before the end
of that dimension. If the value passed to start or end is larger than
the n (the number of elements in this dimension), it represents n.
For slicing to the end of a dimension with unknown size, it is recommended
to pass in INT_MAX. If axes are omitted, they are set to [0, ..., ndim-1].
Following examples will explain how slice works:
@@ -119,15 +119,54 @@ Following examples will explain how slice works:
  }
};
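The worked examples referenced in the comment above are elided by the diff. As a rough, editor-added illustration of the `axes`/`starts`/`ends` semantics only (the values below are made up and are not the original examples), here is a minimal numpy sketch:

```python
import numpy as np

# Hypothetical input, chosen only to illustrate the attribute semantics.
data = np.arange(24).reshape(2, 3, 4)

# axes = [0, 2], starts = [1, -2], ends = [2, 1000]:
# slice axis 0 from 1 to 2, and axis 2 from -2 (i.e. 4 - 2 = 2) to the end;
# an out-of-range end such as INT_MAX is clamped to the dimension size.
out = data[1:2, :, -2:]   # axis 1 is untouched because it is not in `axes`
print(out.shape)          # (1, 3, 2)
```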
class SliceOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    // The gradient w.r.t. the input has the same shape as the input itself.
    auto x_dims = ctx->GetInputDim("Input");
    auto x_grad_name = framework::GradVarName("Input");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }
};
class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    // Describe the backward op: it takes the forward Input and Out@GRAD,
    // produces Input@GRAD, and reuses the forward attributes (axes/starts/ends).
    auto* bind = new framework::OpDesc();
    bind->SetInput("Input", Input("Input"));
    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
    bind->SetAttrMap(Attrs());
    bind->SetType("slice_grad");
    return std::unique_ptr<framework::OpDesc>(bind);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
                  ops::SliceOpGradMaker);
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);
REGISTER_OP_CPU_KERNEL(
    slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -20,3 +20,10 @@ REGISTER_OP_CUDA_KERNEL(
    ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    slice_grad,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
@@ -84,5 +85,79 @@ class SliceKernel : public framework::OpKernel<T> {
    out_t.device(place) = in_t.slice(offsets, extents);
  }
};
template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Dispatch on the rank of Out@GRAD; Eigen needs the rank at compile time.
    size_t rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
                      ->dims()
                      .size();
    switch (rank) {
      case 1:
        SliceCompute<1>(ctx);
        break;
      case 2:
        SliceCompute<2>(ctx);
        break;
      case 3:
        SliceCompute<3>(ctx);
        break;
      case 4:
        SliceCompute<4>(ctx);
        break;
      case 5:
        SliceCompute<5>(ctx);
        break;
      case 6:
        SliceCompute<6>(ctx);
        break;
    }
  }

 private:
  template <size_t D>
  void SliceCompute(const framework::ExecutionContext& context) const {
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    auto* d_out =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* d_input =
        context.Output<framework::Tensor>(framework::GradVarName("Input"));
    d_input->mutable_data<T>(context.GetPlace());
    auto out_dims = d_out->dims();
    auto in_dims = d_input->dims();
    auto axes = context.Attr<std::vector<int>>("axes");
    auto starts = context.Attr<std::vector<int>>("starts");

    // Recover the offset of the slice along each axis, normalizing negative
    // starts the same way the forward kernel does.
    auto offsets = Eigen::array<int, D>();
    auto extents = Eigen::array<int, D>();
    for (size_t i = 0; i < D; ++i) {
      offsets[i] = 0;
      extents[i] = out_dims[i];
    }
    int start;
    for (size_t i = 0; i < axes.size(); ++i) {
      start = starts[i];
      if (start < 0) {
        start = (start + in_dims[axes[i]]);
      }
      start = std::max(start, 0);
      offsets[axes[i]] = start;
    }

    // The gradient of slice is a scatter: zero-pad d_out back to the input
    // shape, with `offsets[i]` zeros before and the remainder after.
    Eigen::array<std::pair<int, int>, D> paddings;
    for (size_t i = 0; i < paddings.size(); ++i) {
      paddings[i].first = offsets[i];
      paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
    }
    auto d_in_t =
        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
            *d_input);
    auto d_out_t =
        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
            *d_out);
    d_in_t.device(place) = d_out_t.pad(paddings, 0);
  }
};
}  // namespace operators
}  // namespace paddle
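For readers following the Eigen kernel above: the gradient of slice places `d_out` back into a zero tensor of the input's shape at the slice offsets, which is exactly the zero-pad computed by `d_out_t.pad(paddings, 0)`. A minimal numpy sketch of that equivalence, with made-up shapes and offsets:

```python
import numpy as np

# Hypothetical shapes: input is (4, 5) and the slice took rows 1:3, cols 2:5.
in_shape = (4, 5)
offsets = (1, 2)                 # start of the slice in each dimension
d_out = np.ones((2, 3))          # upstream gradient, same shape as the slice

# Equivalent of the kernel's paddings: `offsets` zeros before and
# `in_dim - out_dim - offset` zeros after, in every dimension.
paddings = [(off, in_dim - out - off)
            for off, in_dim, out in zip(offsets, in_shape, d_out.shape)]
d_input = np.pad(d_out, paddings, mode='constant', constant_values=0)
assert d_input.shape == in_shape  # gradient matches the input shape
```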
@@ -41,6 +41,9 @@ class TestSliceOp(OpTest):
    def test_check_output(self):
        self.check_output()
    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestCase1(TestSliceOp):
    def config(self):
...
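For context on the new `test_check_grad_normal` above: `OpTest.check_grad` compares the operator's analytic gradient against a numeric estimate, which is why a `max_relative_error` tolerance is given. A standalone finite-difference sketch of that idea (the helper below is illustrative and not part of OpTest):

```python
import numpy as np

def numeric_grad(f, x, eps=1e-3):
    """Central-difference estimate of df/dx for a scalar-valued f."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        plus = f(x)
        x[idx] = orig - eps
        minus = f(x)
        x[idx] = orig                       # restore the perturbed entry
        grad[idx] = (plus - minus) / (2 * eps)
    return grad

# Loss = sum of the slice: its gradient is 1 inside the sliced region and 0
# outside, which matches what the zero-padding grad kernel produces.
x = np.random.rand(4, 5)
g = numeric_grad(lambda a: a[1:3, 2:5].sum(), x)
assert np.allclose(g[1:3, 2:5], 1.0) and np.allclose(g[0], 0.0)
```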