提交 9d6243b6 编写于 作者: W whs 提交者: qingqing01

Fix crop op. (#12603)

* Fix infer shape of crop op.
* Speed crop op.
上级 49ad570e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -188,6 +188,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CPU_KERNEL(
crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
crop_grad, ops::CropGradKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/crop_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CUDA_KERNEL(
crop, ops::CropKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
crop_grad, ops::CropGradKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -58,14 +58,15 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
return res;
}
template <typename T>
class CropKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
template <typename DeviceContext, typename T, size_t D>
void CropFunction(const framework::ExecutionContext& context) {
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_dims = out->dims();
if (out_dims[0] == -1) {
out_dims[0] = x->dims()[0];
}
out->mutable_data<T>(out_dims, context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto out_stride = framework::stride(out->dims());
auto offsets = GetOffsets(context);
......@@ -73,17 +74,58 @@ class CropKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < offsets.size(); ++i) {
offset += (x_stride[i] * offsets[i]);
}
StridedMemcpy<T>(context.device_context(), x_data + offset, x_stride,
out->dims(), out_stride, out_data);
auto x_tensor = EigenTensor<T, D>::From(*x);
auto out_tensor = EigenTensor<T, D>::From(*out);
Eigen::array<int, D> e_offsets;
Eigen::array<int, D> e_shape;
for (size_t i = 0; i < D; ++i) {
e_offsets[i] = offsets[i];
e_shape[i] = out->dims()[i];
}
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_tensor.device(place) = x_tensor.slice(e_offsets, e_shape);
}
template <typename DeviceContext, typename T>
class CropKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
CropFunction<DeviceContext, T, 1>(context);
break;
case 2:
CropFunction<DeviceContext, T, 2>(context);
break;
case 3:
CropFunction<DeviceContext, T, 3>(context);
break;
case 4:
CropFunction<DeviceContext, T, 4>(context);
break;
case 5:
CropFunction<DeviceContext, T, 5>(context);
break;
case 6:
CropFunction<DeviceContext, T, 6>(context);
break;
default:
PADDLE_THROW(
"CropOp only support tensors with no more than 6 dimensions.");
}
}
};
template <typename DeviceContext, typename T, size_t D>
void CropGradFunction(const framework::ExecutionContext& context) {
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* x = context.Input<Tensor>("X");
if (d_x != nullptr) {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(context.GetPlace());
d_x->mutable_data<T>(x->dims(), context.GetPlace());
auto offsets = GetOffsets(context);
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < D; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册