From 9d6243b6fb9544de33302ecb4ef8cfb7129109af Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 13 Aug 2018 22:33:44 +0800 Subject: [PATCH] Fix crop op. (#12603) * Fix infer shape of crop op. * Speed crop op. --- paddle/fluid/operators/crop_op.cc | 5 ++- paddle/fluid/operators/crop_op.cu | 5 ++- paddle/fluid/operators/crop_op.h | 72 ++++++++++++++++++++++++------- 3 files changed, 63 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/crop_op.cc b/paddle/fluid/operators/crop_op.cc index 5b5a220cf9..a2a871efa8 100644 --- a/paddle/fluid/operators/crop_op.cc +++ b/paddle/fluid/operators/crop_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -188,6 +188,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(crop_grad, ops::CropOpGrad); -REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel); +REGISTER_OP_CPU_KERNEL( + crop, ops::CropKernel); REGISTER_OP_CPU_KERNEL( crop_grad, ops::CropGradKernel); diff --git a/paddle/fluid/operators/crop_op.cu b/paddle/fluid/operators/crop_op.cu index 1a39186046..b75678217e 100644 --- a/paddle/fluid/operators/crop_op.cu +++ b/paddle/fluid/operators/crop_op.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/crop_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(crop, ops::CropKernel); +REGISTER_OP_CUDA_KERNEL( + crop, ops::CropKernel); REGISTER_OP_CUDA_KERNEL( crop_grad, ops::CropGradKernel); diff --git a/paddle/fluid/operators/crop_op.h b/paddle/fluid/operators/crop_op.h index 772e80bbea..2d7d33bd4f 100644 --- a/paddle/fluid/operators/crop_op.h +++ b/paddle/fluid/operators/crop_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,32 +58,74 @@ static std::vector GetOffsets(const framework::ExecutionContext& ctx) { return res; } -template +template +void CropFunction(const framework::ExecutionContext& context) { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + auto out_dims = out->dims(); + if (out_dims[0] == -1) { + out_dims[0] = x->dims()[0]; + } + out->mutable_data(out_dims, context.GetPlace()); + auto x_stride = framework::stride(x->dims()); + auto out_stride = framework::stride(out->dims()); + auto offsets = GetOffsets(context); + int64_t offset = 0; + for (size_t i = 0; i < offsets.size(); ++i) { + offset += (x_stride[i] * offsets[i]); + } + + auto x_tensor = EigenTensor::From(*x); + auto out_tensor = EigenTensor::From(*out); + Eigen::array e_offsets; + Eigen::array e_shape; + for (size_t i = 0; i < D; ++i) { + e_offsets[i] = offsets[i]; + e_shape[i] = out->dims()[i]; + } + auto& place = + *context.template device_context().eigen_device(); + out_tensor.device(place) = x_tensor.slice(e_offsets, e_shape); +} + +template class CropKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - const T* x_data = x->data(); - T* out_data = out->mutable_data(context.GetPlace()); - auto x_stride = framework::stride(x->dims()); - auto out_stride = framework::stride(out->dims()); - auto offsets = GetOffsets(context); - int64_t offset = 0; - for (size_t i = 0; i < offsets.size(); ++i) { - offset += (x_stride[i] * offsets[i]); + int rank = context.Input("X")->dims().size(); + switch (rank) { + case 1: + CropFunction(context); + break; + case 2: + CropFunction(context); + break; + case 3: + CropFunction(context); + break; + case 4: + CropFunction(context); + break; + case 5: + CropFunction(context); + break; + case 6: + CropFunction(context); + break; + default: + PADDLE_THROW( + "CropOp only support tensors with no more than 6 dimensions."); } - StridedMemcpy(context.device_context(), x_data + offset, x_stride, - out->dims(), out_stride, out_data); } }; template void CropGradFunction(const framework::ExecutionContext& context) { auto* d_x = context.Output(framework::GradVarName("X")); + auto* x = context.Input("X"); if (d_x != nullptr) { auto* d_out = context.Input(framework::GradVarName("Out")); - d_x->mutable_data(context.GetPlace()); + d_x->mutable_data(x->dims(), context.GetPlace()); auto offsets = GetOffsets(context); Eigen::array, D> paddings; for (size_t i = 0; i < D; ++i) { -- GitLab