diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 4fd72d64a90ae6f16dd1499ceb7fba6e40fe4cea..9b2779b42cad324253dadf27dbff20fd8e8c8e16 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -45,6 +45,7 @@ if(WITH_GPU) add_simple_unittest(BlockExpandOpTest) add_simple_unittest(CropOpTest) add_simple_unittest(SwitchOpTest) + add_simple_unittest(ScaleSubRegionOpTest) endif() add_simple_unittest(Im2ColTest) diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h index ba446bf92da264fafa1fb47a2c30da9cb13176ce..370940532ef40335be54a3e6467de0409e923ec4 100644 --- a/paddle/function/FunctionTest.h +++ b/paddle/function/FunctionTest.h @@ -110,6 +110,7 @@ public: function2_(FunctionBase::funcRegistrar_.createByType(name2)) { function1_->init(config); function2_->init(config); + initArgsCallback_ = nullptr; } ~Compare2Function() {} @@ -170,6 +171,10 @@ public: *seq2_)); } + void registerInitCallback(std::function callback) { + initArgsCallback_ = callback; + } + // output need only contains shape, do not contains data. void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) { size_t size = @@ -340,6 +345,10 @@ protected: initArg(*func1Inputs_[i]); } + if (initArgsCallback_ != nullptr) { + initArgsCallback_(*func1Inputs_[i], i); + } + copyArg_(*func1Inputs_[i], *func2Inputs_[i]); } } @@ -386,6 +395,7 @@ protected: std::shared_ptr seq1_; std::shared_ptr seq2_; test::CopyArgument copyArg_; + std::function initArgsCallback_; }; class CpuGpuFuncCompare diff --git a/paddle/function/ScaleSubRegionOp.cpp b/paddle/function/ScaleSubRegionOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a080505d7df83a6c0a9d88fbcb7863fc0e1f7b21 --- /dev/null +++ b/paddle/function/ScaleSubRegionOp.cpp @@ -0,0 +1,155 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ScaleSubRegionOp.h" +#include "paddle/function/TensorShape.h" + +namespace paddle { + +template <> +void ScaleSubRegion(real* outputs, + const real* inputs, + const real* indices, + const TensorShape shape, + const FuncConfig& conf) { + real value = conf.get("value"); + + int number = shape[0]; + int channel = shape[1]; + int height = shape[2]; + int width = shape[3]; + + memcpy(outputs, inputs, number * channel * height * width * sizeof(real)); + + for (int n = 0; n < number; ++n) { + // indices start from 1 + int offset = n * 6; + for (int c = indices[offset] - 1; c < indices[offset + 1]; ++c) { + for (int h = indices[offset + 2] - 1; h < indices[offset + 3]; ++h) { + for (int w = indices[offset + 4] - 1; w < indices[offset + 5]; ++w) { + int idx = ((n * channel + c) * height + h) * width + w; + outputs[idx] *= value; + } + } + } + } +} + +template <> +void ScaleSubRegionGrad(const real* inGrad, + real* outGrad, + const real* indices, + const TensorShape shape, + const FuncConfig& conf) { + real value = conf.get("value"); + + int number = shape[0]; + int channel = shape[1]; + int height = shape[2]; + int width = shape[3]; + + for (int n = 0; n < number; ++n) { + for (int c = 0; c < channel; ++c) { + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + int idx = ((n * channel + c) * height + h) * width + w; + int offset = n * 6; + if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) && + h >= (indices[offset + 2] - 1) && + h <= (indices[offset + 3] - 1) && + w >= (indices[offset + 4] - 1) && + w <= (indices[offset + 5] - 1)) { + outGrad[idx] += inGrad[idx] * value; + } else { + outGrad[idx] += inGrad[idx]; + } + } + } + } + } +} + +/** + * \brief For each instance, ScaleSubRegion can be used to multiply a value to + * a specified sub continuous region. By providing start index and end + * index for C/H/W, you can specify the location and shape of the region. + * + * Argument in this Function: + * \param inputs A 4-D tensor with shape [N, C, H, W], only one input. + * \param indices A 2-D tensor with shape [N, 6], indicates the sub region. + * \param outputs A 4-D tensor with same shape as inputs, output value. + */ +template +class ScaleSubRegionFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { conf_ = config; } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(2UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape shape = inputs[0].shape(); + + ScaleSubRegion(outputs[0].data(), + inputs[0].data(), + inputs[1].data(), + shape, + conf_); + } + +private: + FuncConfig conf_; +}; + +/** + * \brief The backward propagation of ScaleSubRegion Function. + * + * Argument in this Function: + * \param inputs A 4-D tensor with shape [N, C, H, W], output gradient. + * \param indices A 2-D tensor with shape [N, 6], indicates the sub region. + * \param outputs A 4-D tensor with shape [N, C, H, W], gradient of input value. + */ + +template +class ScaleSubRegionGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { conf_ = config; } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(2UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + + TensorShape shape = inputs[0].shape(); + + ScaleSubRegionGrad(inputs[0].data(), + outputs[0].data(), + inputs[1].data(), + shape, + conf_); + } + +private: + FuncConfig conf_; +}; + +REGISTER_TYPED_FUNC(ScaleSubRegion, CPU, ScaleSubRegionFunc); +REGISTER_TYPED_FUNC(ScaleSubRegionGrad, CPU, ScaleSubRegionGradFunc); +#ifdef PADDLE_WITH_CUDA +REGISTER_TYPED_FUNC(ScaleSubRegion, GPU, ScaleSubRegionFunc); +REGISTER_TYPED_FUNC(ScaleSubRegionGrad, GPU, ScaleSubRegionGradFunc); +#endif + +} // namespace paddle diff --git a/paddle/function/ScaleSubRegionOp.h b/paddle/function/ScaleSubRegionOp.h new file mode 100644 index 0000000000000000000000000000000000000000..0480c8577f3fbf3bc9e94b635df96a31b103e9e3 --- /dev/null +++ b/paddle/function/ScaleSubRegionOp.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Function.h" + +namespace paddle { + +/** + * \brief Function to multiply a value to values in specified sub continuous + * region. Indices must be provided to indcate the location and shape of + * the region and the multiplied value is passed by configure variable. + * + * + * \param[out] outputs Output value. + * \param[in] inputs Input data which contains NCHW information. + * \param[in] indices Indices data to indcate the sub region. + * \param[in] shape Tensor shape of input value. + * \param[in] conf Configure variable which contains the multiplied value. + */ +template +void ScaleSubRegion(real* outputs, + const real* inputs, + const real* indices, + const TensorShape shape, + const FuncConfig& conf); + +/** + * \brief Backward propagation function of ScaleSubRegion. + * + * \param[out] inGrad Gradients of previous layer. + * \param[in] outGrad Output gradient. + * \param[in] indices Indices data. + * \param[in] shape The Shape of input tensor. + * \param[in] conf Configure variable. + */ +template +void ScaleSubRegionGrad(const real* inGrad, + real* outGrad, + const real* indices, + const TensorShape shape, + const FuncConfig& conf); +} // namespace paddle diff --git a/paddle/function/ScaleSubRegionOpGpu.cu b/paddle/function/ScaleSubRegionOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..8aae2e44c3fdc8b516e66ecfd2e04f466a17dde9 --- /dev/null +++ b/paddle/function/ScaleSubRegionOpGpu.cu @@ -0,0 +1,116 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ScaleSubRegionOp.h" +#include "hl_base.h" + +namespace paddle { + +__global__ void KeScaleSubRegion(real* outputs, + const real* inputs, + const real* indices, + real value, + int channel, + int height, + int width, + int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % width; + const int h = (idx / width) % height; + const int c = (idx / width / height) % channel; + const int n = idx / width / height / channel; + + const int offset = n * 6; + if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) && + h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) && + w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) { + outputs[idx] = inputs[idx] * value; + } else { + outputs[idx] = inputs[idx]; + } + } +} + +template <> +void ScaleSubRegion(real* outputs, + const real* inputs, + const real* indices, + const TensorShape shape, + const FuncConfig& conf) { + real value = conf.get("value"); + + int number = shape[0]; + int channel = shape[1]; + int height = shape[2]; + int width = shape[3]; + + size_t nth = number * channel * height * width; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeScaleSubRegion<<>>( + outputs, inputs, indices, value, channel, height, width, nth); + CHECK_SYNC("ScaleSubRegion"); +} + +__global__ void KeScaleSubRegionDiff(const real* inGrad, + real* outGrad, + const real* indices, + real value, + int channel, + int height, + int width, + int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % width; + const int h = (idx / width) % height; + const int c = (idx / width / height) % channel; + const int n = idx / width / height / channel; + + const int offset = n * 6; + if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) && + h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) && + w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) { + outGrad[idx] += inGrad[idx] * value; + } else { + outGrad[idx] += inGrad[idx]; + } + } +} + +template <> +void ScaleSubRegionGrad(const real* inGrad, + real* outGrad, + const real* indices, + const TensorShape shape, + const FuncConfig& conf) { + real value = conf.get("value"); + + int number = shape[0]; + int channel = shape[1]; + int height = shape[2]; + int width = shape[3]; + + size_t nth = number * channel * height * width; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeScaleSubRegionDiff<<>>( + inGrad, outGrad, indices, value, channel, height, width, nth); + CHECK_SYNC("ScaleSubRegionGrad"); +} + +} // namespace paddle diff --git a/paddle/function/ScaleSubRegionOpTest.cpp b/paddle/function/ScaleSubRegionOpTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43331f258dddaa43cbc8cc77519e299de7e98290 --- /dev/null +++ b/paddle/function/ScaleSubRegionOpTest.cpp @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +TEST(ScaleSubRegion, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {5, 32}) { + for (size_t imgSizeH : {5, 33}) { + for (size_t imgSizeW : {5, 32}) { + for (real value : {-0.5, 0.0, 0.5}) { + for (bool firstHalf : {false, true}) { + VLOG(3) << " numSamples=" << numSamples + << " channels=" << channels << " imgSizeH=" << imgSizeH + << " imgSizeW=" << imgSizeW; + + for (bool testGrad : {false, true}) { + CpuGpuFuncCompare compare( + testGrad ? "ScaleSubRegionGrad" : "ScaleSubRegion", + FuncConfig().set("value", value)); + + TensorShape shape{numSamples, channels, imgSizeH, imgSizeW}; + TensorShape indicesShape{numSamples, 6}; + + compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape)); + compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, indicesShape)); + + compare.registerInitCallback([=](BufferArg& arg, size_t index) { + if (index == 1) { + real* data = (real*)arg.data(); + + for (size_t i = 0; i < numSamples; ++i) { + size_t offset = i * 6; + data[offset] = firstHalf ? 1 : channels / 2; + data[offset + 1] = firstHalf ? channels / 2 : channels; + data[offset + 2] = firstHalf ? 1 : imgSizeH / 2; + data[offset + 3] = firstHalf ? imgSizeH / 2 : imgSizeH; + data[offset + 4] = firstHalf ? 1 : imgSizeW / 2; + data[offset + 5] = firstHalf ? imgSizeW / 2 : imgSizeW; + } + } + }); + + compare.addOutputs( + BufferArg( + VALUE_TYPE_FLOAT, shape, testGrad ? ADD_TO : ASSIGN_TO), + testGrad ? ADD_TO : ASSIGN_TO); + compare.run(); + } + } + } + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/ScaleSubRegionLayer.cpp b/paddle/gserver/layers/ScaleSubRegionLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b18bc0c1b9065e09324d8ab4ed165679f6196d00 --- /dev/null +++ b/paddle/gserver/layers/ScaleSubRegionLayer.cpp @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ScaleSubRegionLayer.h" +#include "paddle/utils/Stat.h" +namespace paddle { + +REGISTER_LAYER(scale_sub_region, ScaleSubRegionLayer); + +bool ScaleSubRegionLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + CHECK_EQ(static_cast(inputLayers_.size()), 2); + auto& conf = config_.inputs(0).scale_sub_region_conf(); + value_ = conf.value(); + + createFunction(forward_, "ScaleSubRegion", FuncConfig().set("value", value_)); + createFunction( + backward_, "ScaleSubRegionGrad", FuncConfig().set("value", value_)); + + return true; +} + +void ScaleSubRegionLayer::forward(PassType passType) { + Layer::forward(passType); + auto in0 = getInput(0); + imgH_ = in0.getFrameHeight(); + imgW_ = in0.getFrameWidth(); + if (imgH_ == 0 || imgW_ == 0) { + auto& conf = config_.inputs(0).scale_sub_region_conf(); + imgH_ = conf.image_conf().img_size_y(); + imgW_ = conf.image_conf().img_size(); + } + MatrixPtr imgV = in0.value; + size_t batchSize = imgV->getHeight(); + size_t spatialSize = imgH_ * imgW_; + channelsNum_ = imgV->getWidth() / spatialSize; + shape_ = TensorShape({batchSize, channelsNum_, imgH_, imgW_}); + + resetOutput(batchSize, imgV->getWidth()); + auto out = getOutput(); + out.setFrameHeight(imgH_); + out.setFrameWidth(imgW_); + + MatrixPtr indicesV = getInputValue(1); + indicesShape_ = TensorShape({batchSize, 6}); + + REGISTER_TIMER_INFO("ScaleSubRegionForward", getName().c_str()); + BufferArgs inArgs; + BufferArgs outArgs; + inArgs.addArg(*imgV, shape_); + inArgs.addArg(*indicesV, indicesShape_); + outArgs.addArg(*out.value, shape_, ASSIGN_TO); + forward_[0]->calc(inArgs, outArgs); +} + +void ScaleSubRegionLayer::backward(const UpdateCallback& callback) { + REGISTER_TIMER_INFO("ScaleSubRegionBackward", getName().c_str()); + BufferArgs inArgs; + BufferArgs outArgs; + inArgs.addArg(*getOutputGrad(), shape_); + inArgs.addArg(*getInputValue(1), indicesShape_); + outArgs.addArg(*getInputGrad(0), shape_, ADD_TO); + backward_[0]->calc(inArgs, outArgs); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/ScaleSubRegionLayer.h b/paddle/gserver/layers/ScaleSubRegionLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..a27c56de93bb6fdde0f95cd4c5abe5dfabe4e858 --- /dev/null +++ b/paddle/gserver/layers/ScaleSubRegionLayer.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * \brief For each instance, this layer can be used to multiply a value to a + * specified sub continuous region. By providing start index and end + * index for C/H/W, you can specify the location and shape of the + * region. + * + * input_0: Input value. + * input_1: Indices value to specify the location an shape of the + * region. + */ +class ScaleSubRegionLayer : public Layer { +public: + explicit ScaleSubRegionLayer(const LayerConfig& config) : Layer(config) {} + + ~ScaleSubRegionLayer() {} + + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + + void forward(PassType passType); + + void backward(const UpdateCallback& callback = nullptr); + +protected: + TensorShape shape_; + TensorShape indicesShape_; + size_t imgH_; + size_t imgW_; + size_t channelsNum_; + real value_; +}; + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 1a46fb49153a0aa4228f58db481b950bc2d6de83..3f7d8810513ed4b765219be722cf8ae9adc7909f 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2358,6 +2358,38 @@ TEST(Layer, ScaleShiftLayer) { } } +TEST(Layer, ScaleSubRegionLayer) { + const size_t batchSize = 64; + const size_t size = 4096; + TestConfig config; + config.layerConfig.set_type("scale_sub_region"); + config.inputDefs.push_back({INPUT_DATA, "input", size, 0}); + MatrixPtr indicesV = Matrix::create(batchSize, 6, false, false); + auto* data = indicesV->getData(); + for (size_t i = 0; i < batchSize; ++i) { + data[i * 2] = 2; + data[i * 2 + 1] = 4; + data[i * 2 + 2] = 16; + data[i * 2 + 3] = 32; + data[i * 2 + 4] = 16; + data[i * 2 + 5] = 32; + } + config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, "indices", indicesV, {}}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ScaleSubRegionConfig* scaleSubRegionConf = + input->mutable_scale_sub_region_conf(); + ImageConfig* imgConf = scaleSubRegionConf->mutable_image_conf(); + imgConf->set_img_size(32); + imgConf->set_img_size_y(32); + imgConf->set_channels(4); + scaleSubRegionConf->set_value(2.0); + config.layerConfig.add_inputs(); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "scale_sub_region", batchSize, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/math/tests/TensorCheck.h b/paddle/math/tests/TensorCheck.h index 5bc4a03067a75527fa30e5bb5526f93dc7b9fdcc..b998e5772e70d0a0ec79dc4064dcbaa2c302efd2 100644 --- a/paddle/math/tests/TensorCheck.h +++ b/paddle/math/tests/TensorCheck.h @@ -169,7 +169,7 @@ void TensorCheck(AssertEq compare, count++; } } - EXPECT_EQ(count, 0) << "There are " << count << " different element."; + EXPECT_EQ(count, 0) << "There are " << count << " different elements."; } template diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index ebf0911d6ea0b39d51447859ae2aef485b50b0e6..2d7ff1df98a9a448b447890537f20dd416a9ae9d 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -321,6 +321,11 @@ message ClipConfig { required double max = 2; } +message ScaleSubRegionConfig { + required ImageConfig image_conf = 1; + required float value = 2; +} + message LayerInputConfig { required string input_layer_name = 1; optional string input_parameter_name = 2; @@ -342,6 +347,7 @@ message LayerInputConfig { optional MultiBoxLossConfig multibox_loss_conf = 16; optional DetectionOutputConfig detection_output_conf = 17; optional ClipConfig clip_conf = 18; + optional ScaleSubRegionConfig scale_sub_region_conf = 19; } message LayerConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 0e65598485d8785b3f5b2f1bc7e87f377b35792e..9e2c6f59bd0af1627c79a8a29bd1515ae5c9c6b5 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3801,6 +3801,25 @@ class SwitchOrderLayer(LayerBase): self.config.reshape_conf.width_axis.extend(reshape['width']) +@config_layer('scale_sub_region') +class ScaleSubRegionLayer(LayerBase): + def __init__(self, name, inputs, value, **xargs): + super(ScaleSubRegionLayer, self).__init__( + name, 'scale_sub_region', 0, inputs=inputs, **xargs) + scale_sub_region_conf = self.config.inputs[0].scale_sub_region_conf + scale_sub_region_conf.value = value + + # get channel, width and height from input_0 layer + input_layer = self.get_input_layer(0) + image_conf = scale_sub_region_conf.image_conf + image_conf.img_size = input_layer.width + image_conf.img_size_y = input_layer.height + image_conf.channels = input_layer.size / (input_layer.width * + input_layer.height) + self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size, + image_conf.channels) + + # Deprecated, use a new layer specific class instead @config_func def Layer(name, type, **xargs): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 0fd77a0be60124c882e43a71fd1fed3587ec48a4..9dbfb93217262b0599fd4af75449838041da7a23 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -144,6 +144,7 @@ __all__ = [ 'img_conv3d_layer', 'resize_layer', 'sub_seq_layer', + 'scale_sub_region_layer', ] @@ -255,6 +256,8 @@ class LayerType(object): RESIZE = 'resize' SUB_SEQ_LAYER = 'subseq' + SCALE_SUB_REGION_LAYER = 'scale_sub_region' + @staticmethod def is_layer_type(type_name): """ @@ -7042,3 +7045,54 @@ def sub_seq_layer(input, offsets, sizes, act=None, bias_attr=None, name=None): LayerType.SUB_SEQ_LAYER, parents=[input, offsets, sizes], size=input.size) + + +@wrap_name_default('scale_sub_region') +def scale_sub_region_layer(input, indices, value, name=None): + """ + Given an image or feature map with CHW information, scale_sub_region_layer + can be used to multiply a real value to values of a sub continuous region. + You can provide start and end indices of CHW for each instance. + Please notice that all start indices are counting from 1. + The shape of indices should be [batch_size, 6] and the layout for each row + is [C_Start, C_End, H_Start, H_End, W_Start, W_End]. + + .. code-block:: python + + scale_sub_region = scale_sub_region_layer(input=input, + indices=indices, + value=value) + + :param name: The name of this layer. It is optional. + :type name: basestring + :param input: The input of this layer which should contains CHW information. + :type input: LayerOutput + :param indices: Start index and end index for C H W, the input value should + be a 2-D matrix with shape [batch_size, 6]. + :type indices: LayerOutput. + :param value: value to multiply. + :type value: float + :return: LayerOutput object. + :rtype: LayerOutput + """ + + assert isinstance(input, LayerOutput), ( + 'The first input of scale_sub_region_layer, ' + 'must be a PaddlePaddle layer.') + assert isinstance(indices, LayerOutput), ( + 'The start and end indices for CHW, must be a PaddlePaddle layer.') + assert isinstance(value, float), ( + 'The value to multiply, must be a real value.') + + Layer( + name=name, + type=LayerType.SCALE_SUB_REGION_LAYER, + inputs=[input.name, indices.name], + value=value) + + return LayerOutput( + name, + LayerType.SCALE_SUB_REGION_LAYER, + parents=[input, indices], + num_filters=input.num_filters, + size=input.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 6a4550c209762362d40f8a2afaf526a1fe53ca6b..42aaed7a6469342086b8273eb5b80eaea905f851 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer -test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer) +test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_sub_region_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_sub_region_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..d20133a10ec605654bd3744297673068a77020b8 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_sub_region_layer.protostr @@ -0,0 +1,51 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 2016 + active_type: "" + height: 48 + width: 42 +} +layers { + name: "indices" + type: "data" + size: 6 + active_type: "" +} +layers { + name: "__scale_sub_region_0__" + type: "scale_sub_region" + size: 2016 + active_type: "" + inputs { + input_layer_name: "data" + scale_sub_region_conf { + image_conf { + channels: 1 + img_size: 42 + img_size_y: 48 + } + value: 0.0 + } + } + inputs { + input_layer_name: "indices" + } + height: 48 + width: 42 +} +input_layer_names: "data" +input_layer_names: "indices" +output_layer_names: "__scale_sub_region_0__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "indices" + layer_names: "__scale_sub_region_0__" + input_layer_names: "data" + input_layer_names: "indices" + output_layer_names: "__scale_sub_region_0__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_scale_sub_region_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_scale_sub_region_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..8d4bf28bf1eaf58e1fd0eb62fd10efe998587edd --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_scale_sub_region_layer.py @@ -0,0 +1,11 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +data = data_layer(name='data', size=2016, height=48, width=42) +indices = data_layer(name='indices', size=6) + +scale_sub_region = scale_sub_region_layer( + input=data, indices=indices, value=0.0) + +outputs(scale_sub_region)