提交 e0bca5f8 编写于 作者: D Dun 提交者: qingqing01

Implement slice grad operator. #8130 (#12330)

* Implement slice grad operator. #8130
* test slice grad operator and bug fix
* Fix pre commit style
上级 03dc7b79
...@@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input (Input) of slice op should not be null."); "Input (Input) of slice op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
...@@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
ctx.GetPlace()); ctx.GetPlace());
...@@ -119,15 +119,54 @@ Following examples will explain how slice works: ...@@ -119,15 +119,54 @@ Following examples will explain how slice works:
} }
}; };
class SliceOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc();
bind->SetInput("Input", Input("Input"));
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
bind->SetAttrMap(Attrs());
bind->SetType("slice_grad");
return std::unique_ptr<framework::OpDesc>(bind);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker, REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
paddle::framework::EmptyGradOpMaker); ops::SliceOpGradMaker);
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>, slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, float>, ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>); ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -20,3 +20,10 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -20,3 +20,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>, ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>, ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
slice_grad,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -84,5 +85,79 @@ class SliceKernel : public framework::OpKernel<T> { ...@@ -84,5 +85,79 @@ class SliceKernel : public framework::OpKernel<T> {
out_t.device(place) = in_t.slice(offsets, extents); out_t.device(place) = in_t.slice(offsets, extents);
} }
}; };
template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
size_t rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
->dims()
.size();
switch (rank) {
case 1:
SliceCompute<1>(ctx);
break;
case 2:
SliceCompute<2>(ctx);
break;
case 3:
SliceCompute<3>(ctx);
break;
case 4:
SliceCompute<4>(ctx);
break;
case 5:
SliceCompute<5>(ctx);
break;
case 6:
SliceCompute<6>(ctx);
break;
}
}
private:
template <size_t D>
void SliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_input =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_input->mutable_data<T>(context.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto offsets = Eigen::array<int, D>();
auto extents = Eigen::array<int, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
int start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
start = std::max(start, 0);
offsets[axes[i]] = start;
}
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
}
auto d_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input);
auto d_out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out);
d_in_t.device(place) = d_out_t.pad(paddings, 0);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -41,6 +41,9 @@ class TestSliceOp(OpTest): ...@@ -41,6 +41,9 @@ class TestSliceOp(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestCase1(TestSliceOp): class TestCase1(TestSliceOp):
def config(self): def config(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册