diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 1518a8a654cfb54376a49760dc5873733c916937..f19a1eb7774fc3356eddbe574196504b514a843c 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -37,6 +37,7 @@ if(WITH_GPU) add_simple_unittest(MulOpTest) add_simple_unittest(CosSimOpTest) add_simple_unittest(RowConvOpTest) + add_simple_unittest(CropOpTest) endif() add_simple_unittest(ConvOpTest) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d47d9c149c3c766c63c22abbab88e9042b8b3ee --- /dev/null +++ b/paddle/function/CropOp.cpp @@ -0,0 +1,177 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropOp.h" +#include "paddle/math/Vector.h" +#include "paddle/function/TensorShape.h" +namespace paddle { + +static inline CropConf castToCropConf(const FuncConfig& conf) { + return {conf.get>("crop_corner"), + conf.get>("crop_shape")}; +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop) { + int cCrop = crop.corner[0]; + int hCrop = crop.corner[1]; + int wCrop = crop.corner[2]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = crop.shape[0]; + int outH = crop.shape[1]; + int outW = crop.shape[2]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < outC; c++) { + for (int h = 0; h < outH; h++) { + int outoff = ((n * outC + c) * outH + h) * outW; + int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop; + memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real)); + } + } + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape outShape, + const CropConf& crop) { + int cCrop = crop.corner[0]; + int hCrop = crop.corner[1]; + int wCrop = crop.corner[2]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = crop.shape[0]; + int inH = crop.shape[1]; + int inW = crop.shape[2]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < inC; c++) { + for (int h = 0; h < inH; h++) { + int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop; + int inoff = ((n * inC + c) * inH + h) * inW; + CpuVector inG = CpuVector(inW, const_cast(inGrad + inoff)); + CpuVector outG = CpuVector(inW, outGrad + outoff); + outG += inG; + } + } + } +} + +/** + * \brief Crop input according to the specify corner and shape. + * The input and output is a 4D tensor. In CropFunc, we only + * crop the 2nd to 4th dimension. + * + * Argument in this Function: + * \param pad_ A struct object contains the cropping corner and shape. + * \param inputs A 4D tensor, only one input. + * \param outputs A 4D tensor, the output value after cropping. + * + * For example, + * Input(2,2,2,3) = [ + * [ [[1,2,3], [3,4,5]], + * [[2,3,5], [1,6,7]] ], + * [ [[4,3,1], [1,8,7]], + * [[3,8,9], [2,3,5]] ] + * ] # the input shape is (2,2,2,3) + * + * pad_: if corner = (0,1,1) and crop_shape = (2,1,2) + * Output(2,2,1,2) = [ + * [ [[4,5]], + * [[6,7]] ], + * [ [[8,7]], + * [[3,5]] ] + * ] # the input shape is (2,2,2,3) + */ +template +class CropFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + crop_ = castToCropConf(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].shape()[1], crop_.shape[0]); + CHECK_EQ(outputs[0].shape()[2], crop_.shape[1]); + CHECK_EQ(outputs[0].shape()[3], crop_.shape[2]); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape inShape = inputs[0].shape(); + + Crop( + outputs[0].data(), inputs[0].data(), inShape, crop_); + } + +private: + CropConf crop_; +}; + +/** + * \brief The backward propagation of cropping Function. + * + * Argument in this Function: + * \param crop_ The same meaning as it in CropFunc. + * \param inputs The gradient with respect to the output value of CropFunc. + * \param outputs The gradient with respect to the input value of CropFunc. + */ + +template +class CropGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + crop_ = castToCropConf(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(inputs[0].shape()[1], crop_.shape[0]); + CHECK_EQ(inputs[0].shape()[2], crop_.shape[1]); + CHECK_EQ(inputs[0].shape()[3], crop_.shape[2]); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape outShape = outputs[0].shape(); + + CropGrad( + inputs[0].data(), outputs[0].data(), outShape, crop_); + } + +private: + CropConf crop_; +}; + +REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(Crop, GPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc); +#endif + +} // namespace paddle diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h new file mode 100644 index 0000000000000000000000000000000000000000..78a55bd43e9a8ef4ce98308306ac5357ab8c6185 --- /dev/null +++ b/paddle/function/CropOp.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Function.h" + +namespace paddle { + +struct CropConf { + /// The upper left corner of croped result + std::vector corner; + /// The shape of croped result + std::vector shape; +}; + +/** + * \brief This funtion crops inputs according to the specify start point and + *shape. + * + * \param[out] outputs save results. + * \param[in] inputs input data. + * \param[in] inShape the shape of input tensor. + * \param[in] crop the cropping config + */ +template +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop); + +/** + * \brief Cropping operation backward. + * + * \param[out] inGrad gradients of previous layer + * \param[in] outGrad output gradient + * \param[in] inShape the shape of input tensor. + * \param[in] crop the cropping config + */ +template +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const CropConf& crop); +} // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..f7d7d03abd2575747e7cd2b8eb997057895fbf11 --- /dev/null +++ b/paddle/function/CropOpGpu.cu @@ -0,0 +1,109 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "CropOp.h" + +namespace paddle { + +__global__ void KeCrop(real* outputs, const real* inputs, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % outW; + const int h = (idx / outW) % outH; + const int c = (idx / outW / outH) % outC; + const int n = idx / outW / outH / outC; + + const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w; + outputs[idx] = inputs[off]; + } +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop) { + int cropC = crop.corner[0]; + int cropH = crop.corner[1]; + int cropW = crop.corner[2]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = crop.shape[0]; + int outH = crop.shape[1]; + int outW = crop.shape[2]; + + size_t nth = num * outC * outH * outW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCrop<<>> + (outputs, inputs, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("Crop"); +} + +__global__ void KeCropDiff(const real* inGrad, real* outGrad, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % inW; + const int h = (idx / inW) % inH; + const int c = (idx / inW / inH) % inC; + const int n = idx / inW / inH / inC; + + const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w; + + outGrad[off] += inGrad[idx]; + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape outShape, + const CropConf& crop) { + int cropC = crop.corner[0]; + int cropH = crop.corner[1]; + int cropW = crop.corner[2]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = crop.shape[0]; + int inH = crop.shape[1]; + int inW = crop.shape[2]; + + size_t nth = num * inC * inH * inW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCropDiff <<>> + (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("CropGrad"); +} + +} // namespace paddle diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62b4bd9fdea040a00409313f890a9f5b80a949e4 --- /dev/null +++ b/paddle/function/CropOpTest.cpp @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +TEST(Crop, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {5, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; + for (bool test_grad : {false, true}) { + FunctionCompare compare( + test_grad ? "CropGrad" : "Crop", + FuncConfig() + .set>("crop_corner", {1, 1, 1}) + .set>("crop_shape", {2, 3, 3})); + TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; + TensorShape outDims{numSamples, 2, 3, 3}; + compare.addInputs( + BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); + compare.addOutputs(BufferArg( + VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO)); + compare.run(); + } + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab23d4617e3ae7f7c038bea6a153bc4d1fff79cc --- /dev/null +++ b/paddle/gserver/layers/CropLayer.cpp @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropLayer.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(crop, CropLayer); + +bool CropLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + + auto& crop_conf = config_.inputs(0).crop_conf(); + auto& img_conf = crop_conf.image_conf(); + CHECK_EQ(config_.inputs_size(), 1); + inDims_ = TensorShape( + {0, + img_conf.channels(), + img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(), + img_conf.img_size()}); + + crop_corner_ = {crop_conf.crop_corner(0), + crop_conf.crop_corner(1), + crop_conf.crop_corner(2)}; + crop_shape_ = {crop_conf.crop_shape(0), + crop_conf.crop_shape(1), + crop_conf.crop_shape(2)}; + + outDims_ = TensorShape(4); + setOutDims(0); + + createFunction(forward_, + "Crop", + FuncConfig() + .set("crop_corner", crop_corner_) + .set("crop_shape", crop_shape_)); + createFunction(backward_, + "CropGrad", + FuncConfig() + .set("crop_corner", crop_corner_) + .set("crop_shape", crop_shape_)); + + return true; +} + +void CropLayer::setOutDims(const size_t batchSize) { + outDims_.reshape({batchSize, crop_shape_[0], crop_shape_[1], crop_shape_[2]}); +} + +void CropLayer::setTensorDim(const size_t batchSize) { + CHECK_EQ(static_cast(inputLayers_.size()), 1); + inDims_.setDim(0, batchSize); + int h = inputLayers_[0]->getOutput().getFrameHeight(); + if (h != 0) inDims_.setDim(2, h); + int w = inputLayers_[0]->getOutput().getFrameWidth(); + if (w != 0) inDims_.setDim(3, w); + setOutDims(batchSize); +} + +void CropLayer::forward(PassType passType) { + Layer::forward(passType); + MatrixPtr input = inputLayers_[0]->getOutputValue(); + size_t batchSize = input->getHeight(); + setTensorDim(batchSize); + int size = outDims_[1] * outDims_[2] * outDims_[3]; + resetOutput(batchSize, size); + MatrixPtr outV = getOutputValue(); + REGISTER_TIMER_INFO("CropForward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), inDims_); + outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); +} + +void CropLayer::backward(const UpdateCallback& callback) { + (void)callback; + REGISTER_TIMER_INFO("CropBackward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outDims_); + outputs.addArg(*getInputGrad(0), inDims_, ADD_TO); + backward_[0]->calc(inputs, outputs); +} +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..3ce89707cafd80697068becf6603dda9907047b2 --- /dev/null +++ b/paddle/gserver/layers/CropLayer.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * \brief This layer crop inputs according to the specify corner and shape. + * The input and output is a 4D tensor. Cropping from the 2nd to + * the 4th dimenstion. + */ +class CropLayer : public Layer { +public: + explicit CropLayer(const LayerConfig& config) : Layer(config) {} + + ~CropLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +protected: + void setOutDims(const size_t batchSize); + void setTensorDim(const size_t batchSize); + + std::vector crop_corner_; + std::vector crop_shape_; + TensorShape inDims_; + TensorShape outDims_; +}; +} // namespace paddle