diff --git a/CMakeLists.txt b/CMakeLists.txt index dcff6b54cafce35846627e78cfcdac65fae7e686..2a6b0a20e441676c85c9ed8f8ad1a6e7abdf1ea8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,6 @@ # limitations under the License cmake_minimum_required(VERSION 3.0) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index a5b14c0c71c18da1bb0b506c663f8680b1c3830a..2bec00cdb2d32d01a5a24e662bcca07f4154939c 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -36,6 +36,7 @@ if(WITH_GPU) add_simple_unittest(MulOpTest) add_simple_unittest(CosSimOpTest) add_simple_unittest(RowConvOpTest) + add_simple_unittest(CropOpTest) endif() add_simple_unittest(ConvOpTest) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f12ee43e3d72f9ac776eaff93914228850694dd2 --- /dev/null +++ b/paddle/function/CropOp.cpp @@ -0,0 +1,177 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropOp.h" +#include "paddle/function/TensorShape.h" +#include "paddle/math/Vector.h" + +namespace paddle { + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < outC; c++) { + for (int h = 0; h < outH; h++) { + int outoff = ((n * outC + c) * outH + h) * outW; + int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop; + memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real)); + } + } + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < inC; c++) { + for (int h = 0; h < inH; h++) { + int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop; + int inoff = ((n * inC + c) * inH + h) * inW; + CpuVector inG = CpuVector(inW, const_cast(inGrad + inoff)); + CpuVector outG = CpuVector(inW, outGrad + outoff); + outG += inG; + } + } + } +} + +/** + * \brief Crop input according to the specify corner and shape. + * The input and output is a 4D tensor. In CropFunc, we only + * crop the 2nd to 4th dimension. + * + * Argument in this Function: + * \param pad_ A struct object contains the cropping corner and shape. + * \param inputs A 4D tensor, only one input. + * \param outputs A 4D tensor, the output value after cropping. + * + * For example, + * Input(2,2,2,3) = [ + * [ [[1,2,3], [3,4,5]], + * [[2,3,5], [1,6,7]] ], + * [ [[4,3,1], [1,8,7]], + * [[3,8,9], [2,3,5]] ] + * ] # the input shape is (2,2,2,3) + * + * pad_: if corner = (0,1,1) and crop_shape = (2,1,2) + * Output(2,2,1,2) = [ + * [ [[4,5]], + * [[6,7]] ], + * [ [[8,7]], + * [[3,5]] ] + * ] # the input shape is (2,2,2,3) + */ +template +class CropFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { conf_ = config; } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape inShape = inputs[0].shape(); + TensorShape outShape = outputs[0].shape(); + + Crop(outputs[0].data(), + inputs[0].data(), + inShape, + outShape, + conf_); + } + +private: + FuncConfig conf_; +}; + +/** + * \brief The backward propagation of cropping Function. + * + * Argument in this Function: + * \param crop_ The same meaning as it in CropFunc. + * \param inputs The gradient with respect to the output value of CropFunc. + * \param outputs The gradient with respect to the input value of CropFunc. + */ + +template +class CropGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { conf_ = config; } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + + TensorShape outShape = outputs[0].shape(); + TensorShape inShape = inputs[0].shape(); + + CropGrad(inputs[0].data(), + outputs[0].data(), + inShape, + outShape, + conf_); + } + +private: + FuncConfig conf_; +}; + +REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(Crop, GPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc); +#endif + +} // namespace paddle diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h new file mode 100644 index 0000000000000000000000000000000000000000..87986fbdc7e33aeb24d947e82a5d67ba23f532de --- /dev/null +++ b/paddle/function/CropOp.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Function.h" + +namespace paddle { + +/** + * \brief This funtion crops inputs according to the specify start point and + *shape. + * + * \param[out] outputs save results. + * \param[in] inputs input data. + * \param[in] inShape the shape of input tensor. + * \param[in] conf the cropping config + */ +template +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf); + +/** + * \brief Cropping operation backward. + * + * \param[out] inGrad gradients of previous layer + * \param[in] outGrad output gradient + * \param[in] inShape the shape of input tensor. + * \param[in] conf the cropping config + */ +template +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf); +} // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..37ce6de0647e5e06a231710b5a53089533de2407 --- /dev/null +++ b/paddle/function/CropOpGpu.cu @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "CropOp.h" + +namespace paddle { + +__global__ void KeCrop(real* outputs, const real* inputs, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % outW; + const int h = (idx / outW) % outH; + const int c = (idx / outW / outH) % outC; + const int n = idx / outW / outH / outC; + + const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w; + outputs[idx] = inputs[off]; + } +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + size_t nth = num * outC * outH * outW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCrop<<>> + (outputs, inputs, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("Crop"); +} + +__global__ void KeCropDiff(const real* inGrad, real* outGrad, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % inW; + const int h = (idx / inW) % inH; + const int c = (idx / inW / inH) % inC; + const int n = idx / inW / inH / inC; + + const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w; + + outGrad[off] += inGrad[idx]; + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + size_t nth = num * inC * inH * inW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCropDiff <<>> + (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("CropGrad"); +} + +} // namespace paddle diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f11abfdf6f752857e0a75c62fb2b5c089c206d9 --- /dev/null +++ b/paddle/function/CropOpTest.cpp @@ -0,0 +1,49 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +TEST(Crop, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {5, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; + for (bool test_grad : {false, true}) { + CpuGpuFuncCompare compare( + test_grad ? "CropGrad" : "Crop", + FuncConfig() + .set>("crop_corner", {0, 1, 1, 1}) + .set>("crop_shape", {0, 2, 3, 3})); + TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; + TensorShape outDims{numSamples, 2, 3, 3}; + compare.addInputs( + BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); + compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, + test_grad ? inDims : outDims, + test_grad ? ADD_TO : ASSIGN_TO), + test_grad ? ADD_TO : ASSIGN_TO); + compare.run(); + } + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69ad913420bdb6e1b2ed0618b7f9b78d7477be99 --- /dev/null +++ b/paddle/gserver/layers/CropLayer.cpp @@ -0,0 +1,146 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropLayer.h" +#include "paddle/utils/Stat.h" +namespace paddle { + +REGISTER_LAYER(crop, CropLayer); + +bool CropLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + CHECK_LE(static_cast(inputLayers_.size()), 2); + CHECK_GE(static_cast(inputLayers_.size()), 1); + crop_axis_ = config_.axis(); + for (int i = 0; i < config_.offset_size(); i++) { + crop_offsets_.push_back(config_.offset(i)); + } + + // 1. get input_0 shape + auto& input0_img_conf = config_.inputs(0).image_conf(); + inDims_ = TensorShape({0, + input0_img_conf.channels(), + input0_img_conf.has_img_size_y() + ? input0_img_conf.img_size_y() + : input0_img_conf.img_size(), + input0_img_conf.img_size()}); + // 2. get target dims from config + if (config_.inputs_size() == 1) { + targetDims_ = TensorShape({config_.shape(0), + config_.shape(1), + config_.shape(2), + config_.shape(3)}); + } else { + // 2. get input_1 shape + auto& input1_img_conf = config_.inputs(1).image_conf(); + targetDims_ = TensorShape({0, + input1_img_conf.channels(), + input1_img_conf.has_img_size_y() + ? input1_img_conf.img_size_y() + : input1_img_conf.img_size(), + input1_img_conf.img_size()}); + } + + // 3. get final crop corner + int dimSize = 4; + crop_corner_ = {0, 0, 0, 0}; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + if (crop_offsets_.size() > 1) { + crop_corner_[i] = crop_offsets_[i - crop_axis_]; + } else { + crop_corner_[i] = crop_offsets_[0]; + } + } + } + + outDims_ = TensorShape(4); + + createFunction( + forward_, "Crop", FuncConfig().set("crop_corner", crop_corner_)); + createFunction( + backward_, "CropGrad", FuncConfig().set("crop_corner", crop_corner_)); + + return true; +} + +void CropLayer::setOutDims() { + MatrixPtr input = inputLayers_[1]->getOutputValue(); + size_t batchSize = input->getHeight(); + // get target dims from input_1 + if (config_.inputs_size() == 2) { + targetDims_.setDim(0, batchSize); + int ch = config_.inputs(0).image_conf().channels(); + if (ch != 0) targetDims_.setDim(1, ch); + int h = inputLayers_[1]->getOutput().getFrameHeight(); + if (h != 0) targetDims_.setDim(2, h); + int w = inputLayers_[1]->getOutput().getFrameWidth(); + if (w != 0) targetDims_.setDim(3, w); + } + // get final crop shape from target dims and crop axis + std::vector crop_shape; + int dimSize = 4; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + crop_shape.push_back(targetDims_[i]); + } else { + crop_shape.push_back(inDims_[i]); + } + } + + outDims_.reshape( + {crop_shape[0], crop_shape[1], crop_shape[2], crop_shape[3]}); + output_.setFrameHeight(crop_shape[2]); + output_.setFrameWidth(crop_shape[3]); +} + +void CropLayer::setInDims() { + MatrixPtr input = inputLayers_[0]->getOutputValue(); + size_t batchSize = input->getHeight(); + inDims_.setDim(0, batchSize); + int h = inputLayers_[0]->getOutput().getFrameHeight(); + if (h != 0) inDims_.setDim(2, h); + int w = inputLayers_[0]->getOutput().getFrameWidth(); + if (w != 0) inDims_.setDim(3, w); +} + +void CropLayer::forward(PassType passType) { + Layer::forward(passType); + setInDims(); + setOutDims(); + int size = outDims_[1] * outDims_[2] * outDims_[3]; + resetOutput(outDims_[0], size); + MatrixPtr outV = getOutputValue(); + REGISTER_TIMER_INFO("CropForward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), inDims_); + outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); +} + +void CropLayer::backward(const UpdateCallback& callback) { + (void)callback; + REGISTER_TIMER_INFO("CropBackward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outDims_); + outputs.addArg(*getInputGrad(0), inDims_, ADD_TO); + backward_[0]->calc(inputs, outputs); +} +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..6b6202621023575c1c83049ecbd019656c726e3f --- /dev/null +++ b/paddle/gserver/layers/CropLayer.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * \brief This layer crop input according to the specify conf. + * input_0: input to be cropped + * input_1: optional reference input + * axis: start dimension to be croped + * offset: offset of cropping in each dimension + * shape: if reference input layer was not setted, + * crop input as this shape conf + */ +class CropLayer : public Layer { +public: + explicit CropLayer(const LayerConfig& config) : Layer(config) {} + + ~CropLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +protected: + void setOutDims(); + void setInDims(); + + int32_t crop_axis_; + std::vector crop_offsets_; + std::vector crop_corner_; + TensorShape inDims_; + TensorShape targetDims_; + TensorShape outDims_; +}; +} // namespace paddle diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 92f6cbcfe5a0e23c5939b1689a3e339367450387..a43adc7ce7db937bd62ea9bf1533b8a5899c259a 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -56,7 +56,7 @@ add_test(NAME test_DetectionOutput add_unittest_without_exec(test_ConvUnify test_ConvUnify.cpp LayerGradUtil.cpp) - + add_test(NAME test_ConvUnify COMMAND test_ConvUnify) ################# test_BatchNorm ####################### diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 67251f08e34faff57d9e6fd6a1163ba655619a8b..9af083468c0f01218117211f9e4931ca0669e96a 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1802,6 +1802,34 @@ TEST(Layer, RowConvLayer) { } } +TEST(Layer, CropLayer) { + TestConfig config; + // config input_0 + config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ImageConfig* img = input->mutable_image_conf(); + img->set_channels(4); + img->set_img_size(16); + config.layerConfig.set_axis(2); + config.layerConfig.add_offset(0); + config.layerConfig.add_offset(0); + + // config input_1 + config.inputDefs.push_back({INPUT_DATA, "layer_1", 128, 0}); + input = config.layerConfig.add_inputs(); + img = input->mutable_image_conf(); + img->set_channels(2); + img->set_img_size(8); + + // config crop layer + config.layerConfig.set_type("crop"); + config.layerConfig.set_name("cropLayer"); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "crop", 100, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 37cd16c79890738f6d8966579e15686c653d4df3..83f72c137bdf5e55f28be908321bd2ccd6c906fe 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -472,10 +472,16 @@ message LayerConfig { // blank label used in ctc loss optional uint32 blank = 52 [default = 0]; - // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which + // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which // controls the scope of pooling operation. can be set > 0. // leave empty or set to -1 to disable this stride pooling. optional int32 seq_pool_stride = 53 [default = -1]; + + // for crop layer + optional int32 axis = 54 [default = 2]; + repeated uint32 offset = 55; + repeated uint32 shape = 56; + } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index ef3d81e4c0791ca7847dc607682fa39ff15967da..ab81e67579e39a34e3ace18d14434eb86b66fa5b 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1998,6 +1998,23 @@ class PadLayer(LayerBase): self.config.size = out_ch * out_h * out_w +@config_layer('crop') +class CropLayer(LayerBase): + def __init__(self, name, inputs, axis, offset, shape, **xargs): + super(CropLayer, self).__init__(name, 'crop', 0, inputs=inputs, **xargs) + self.config.axis = axis + self.config.offset.extend(offset) + self.config.shape.extend(shape) + + # get channel, width and height from input_0 layer + input_layer = self.get_input_layer(0) + image_conf = self.config.inputs[0].image_conf + image_conf.img_size = input_layer.width + image_conf.img_size_y = input_layer.height + image_conf.channels = input_layer.size / (input_layer.width * + input_layer.height) + + @config_layer('batch_norm') class BatchNormLayer(LayerBase): layer_type = 'batch_norm' diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 78aa0778f8d1dca9fae82f0411be5a00e636cbc9..fdb6f83f2ba510232714fb8a9c7c1af837a753ff 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -127,6 +127,7 @@ __all__ = [ 'dropout_layer', 'prelu_layer', 'gated_unit_layer', + 'crop_layer', ] @@ -218,6 +219,7 @@ class LayerType(object): SMOOTH_L1 = 'smooth_l1' PRELU = 'prelu' + CROP_LAYER = 'crop' @staticmethod def is_layer_type(type_name): @@ -5970,3 +5972,52 @@ def gated_unit_layer(input, name="%s_gated_act" % name, input=dotmul_operator(input_proj, gate), layer_attr=layer_attr) + + +@wrap_name_default() +@layer_support() +def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None): + """ + The crop layer crops images by offset and shape. User can set crop shape by + args 'shape' explicitly or by reference input layer. + + The example usage is: + + .. code-block:: python + crop = crop_layer(input=[image_input, reference_input], axis=2, offset=[2, 3]) + + :param input: The input layer.If two inputs were setted, + the second input will be regarded as reference input + :type input: LayerOutput or Sequence + :param offset: The crop offset + :type offset: Sequence + :param axis: start axis to be cropped. To image input layer: + - 0: batch size + - 1: channels + - 2: height + - 3: width + :type partial_sum: int + :param shape: The shape to be cropped. Default is None. + :type shape: Sequence | None + :param name: Name of this layer. + :type name: basestring + :return: LayerOutput object. + :rtype: LayerOutput + """ + if isinstance(input, LayerOutput): + input = [input] + else: + assert isinstance(input, collections.Sequence) + l = Layer( + inputs=[x.name for x in input], + axis=axis, + offset=offset, + shape=shape, + name=name, + type=LayerType.CROP_LAYER, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name=name, + layer_type=LayerType.CROP_LAYER, + parents=input, + size=l.config.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_crop.py b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..8314a7e9a5586647c70ff010156817110919c72b --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py @@ -0,0 +1,21 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +data = data_layer(name='data', size=2016, height=48, width=42) +refernce_data = data_layer(name='data', size=768, height=16, width=16) + +conv = img_conv_layer( + input=data, + filter_size=3, + num_channels=1, + num_filters=16, + padding=1, + act=LinearActivation(), + bias_attr=True) + +pool = img_pool_layer(input=conv, pool_size=2, stride=2, pool_type=MaxPooling()) + +crop = crop_layer(input=[pool, refernce_data], axis=2) + +outputs(pad)