diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index d4e9d53e5c0955912a594fe8cd9cd41a4080a2d2..203506d7ab84e5a5be2232b077eac2d433a99766 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -82,6 +82,11 @@ maxout .. autoclass:: paddle.v2.layer.maxout :noindex: +roi_pool +-------- +.. autoclass:: paddle.v2.layer.roi_pool + :noindex: + Norm Layer ========== diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index fb2c69105627f663ddcce07d31526c9e4278e863..9428b8a07ea0af005f6e960ddaa02da624ad9d97 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -98,5 +98,23 @@ void Scope::DeleteScope(Scope* scope) { delete scope; } +void Scope::Rename(const std::string& origin_name, + const std::string& new_name) const { + auto origin_it = vars_.find(origin_name); + PADDLE_ENFORCE(origin_it != vars_.end(), + "Cannot find original variable with name %s", origin_name); + auto new_it = vars_.find(new_name); + PADDLE_ENFORCE(new_it == vars_.end(), + "The variable with name %s is already in the scope", new_name); + vars_[new_name] = origin_it->second; + vars_.erase(origin_it); +} + +std::string Scope::Rename(const std::string& origin_name) const { + auto var_name = string::Sprintf("%p.%d", this, vars_.size()); + Rename(origin_name, var_name); + return var_name; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index fb660949394149ebf2c6172a0ac3f4c7594f4286..c2aafb6ad825f9bd9ffef754923a15afdeaa8e5c 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -68,11 +68,18 @@ class Scope { // enumerate all the variables current contains. std::vector GetAllNames(bool recursive = false) const; + // Rename variable to a new name + void Rename(const std::string& origin_name, + const std::string& new_name) const; + + // Rename variable to a new name and return the new name + std::string Rename(const std::string& origin_name) const; + private: // Call Scope::NewScope for a sub-scope. explicit Scope(Scope const* parent) : parent_(parent) {} - std::unordered_map vars_; + mutable std::unordered_map vars_; mutable std::list kids_; Scope const* parent_{nullptr}; diff --git a/paddle/gserver/layers/ROIPoolLayer.cpp b/paddle/gserver/layers/ROIPoolLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99cfddb0cf3337745a716a8c329713c18b99eda3 --- /dev/null +++ b/paddle/gserver/layers/ROIPoolLayer.cpp @@ -0,0 +1,220 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ROIPoolLayer.h" + +namespace paddle { + +REGISTER_LAYER(roi_pool, ROIPoolLayer); + +bool ROIPoolLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + + const ROIPoolConfig& layerConf = config_.inputs(0).roi_pool_conf(); + pooledWidth_ = layerConf.pooled_width(); + pooledHeight_ = layerConf.pooled_height(); + spatialScale_ = layerConf.spatial_scale(); + + return true; +} + +void ROIPoolLayer::forward(PassType passType) { + Layer::forward(passType); + + const ROIPoolConfig& layerConf = config_.inputs(0).roi_pool_conf(); + height_ = getInput(0).getFrameHeight(); + if (!height_) height_ = layerConf.height(); + width_ = getInput(0).getFrameWidth(); + if (!width_) width_ = layerConf.width(); + channels_ = getInputValue(0)->getWidth() / width_ / height_; + + size_t batchSize = getInput(0).getBatchSize(); + size_t numROIs = getInput(1).getBatchSize(); + + MatrixPtr dataValue = getInputValue(0); + MatrixPtr roiValue = getInputValue(1); + resetOutput(numROIs, channels_ * pooledHeight_ * pooledWidth_); + MatrixPtr outputValue = getOutputValue(); + + if (useGpu_) { // TODO(guosheng): implement on GPU later + MatrixPtr dataCpuBuffer; + Matrix::resizeOrCreate(dataCpuBuffer, + dataValue->getHeight(), + dataValue->getWidth(), + false, + false); + MatrixPtr roiCpuBuffer; + Matrix::resizeOrCreate(roiCpuBuffer, + roiValue->getHeight(), + roiValue->getWidth(), + false, + false); + dataCpuBuffer->copyFrom(*dataValue); + roiCpuBuffer->copyFrom(*roiValue); + dataValue = dataCpuBuffer; + roiValue = roiCpuBuffer; + MatrixPtr outputCpuBuffer; + Matrix::resizeOrCreate(outputCpuBuffer, + outputValue->getHeight(), + outputValue->getWidth(), + false, + false); + outputCpuBuffer->copyFrom(*outputValue); + outputValue = outputCpuBuffer; + } + + real* bottomData = dataValue->getData(); + size_t batchOffset = dataValue->getWidth(); + size_t channelOffset = height_ * width_; + real* bottomROIs = roiValue->getData(); + size_t roiOffset = roiValue->getWidth(); + size_t poolChannelOffset = pooledHeight_ * pooledWidth_; + + real* outputData = outputValue->getData(); + Matrix::resizeOrCreate(maxIdxs_, + numROIs, + channels_ * pooledHeight_ * pooledWidth_, + false, + false); + real* argmaxData = maxIdxs_->getData(); + + for (size_t n = 0; n < numROIs; ++n) { + // the first five elememts of each RoI should be: + // batch_idx, roi_x_start, roi_y_start, roi_x_end, roi_y_end + size_t roiBatchIdx = bottomROIs[0]; + size_t roiStartW = round(bottomROIs[1] * spatialScale_); + size_t roiStartH = round(bottomROIs[2] * spatialScale_); + size_t roiEndW = round(bottomROIs[3] * spatialScale_); + size_t roiEndH = round(bottomROIs[4] * spatialScale_); + CHECK_GE(roiBatchIdx, 0); + CHECK_LT(roiBatchIdx, batchSize); + size_t roiHeight = std::max(roiEndH - roiStartH + 1, 1UL); + size_t roiWidth = std::max(roiEndW - roiStartW + 1, 1UL); + real binSizeH = + static_cast(roiHeight) / static_cast(pooledHeight_); + real binSizeW = + static_cast(roiWidth) / static_cast(pooledWidth_); + real* batchData = bottomData + batchOffset * roiBatchIdx; + for (size_t c = 0; c < channels_; ++c) { + for (size_t ph = 0; ph < pooledHeight_; ++ph) { + for (size_t pw = 0; pw < pooledWidth_; ++pw) { + size_t hstart = static_cast(std::floor(ph * binSizeH)); + size_t wstart = static_cast(std::floor(pw * binSizeW)); + size_t hend = static_cast(std::ceil((ph + 1) * binSizeH)); + size_t wend = static_cast(std::ceil((pw + 1) * binSizeW)); + hstart = std::min(std::max(hstart + roiStartH, 0UL), height_); + wstart = std::min(std::max(wstart + roiStartW, 0UL), width_); + hend = std::min(std::max(hend + roiStartH, 0UL), height_); + wend = std::min(std::max(wend + roiStartW, 0UL), width_); + + bool isEmpty = (hend <= hstart) || (wend <= wstart); + size_t poolIndex = ph * pooledWidth_ + pw; + if (isEmpty) { + outputData[poolIndex] = 0; + argmaxData[poolIndex] = -1; + } + + for (size_t h = hstart; h < hend; ++h) { + for (size_t w = wstart; w < wend; ++w) { + size_t index = h * width_ + w; + if (batchData[index] > outputData[poolIndex]) { + outputData[poolIndex] = batchData[index]; + argmaxData[poolIndex] = index; + } + } + } + } + } + batchData += channelOffset; + outputData += poolChannelOffset; + argmaxData += poolChannelOffset; + } + bottomROIs += roiOffset; + } + if (useGpu_) { + getOutputValue()->copyFrom(*outputValue); + } +} + +void ROIPoolLayer::backward(const UpdateCallback& callback) { + MatrixPtr inGradValue = getInputGrad(0); + MatrixPtr outGradValue = getOutputGrad(); + MatrixPtr roiValue = getInputValue(1); + + if (useGpu_) { + MatrixPtr inGradCpuBuffer; + Matrix::resizeOrCreate(inGradCpuBuffer, + inGradValue->getHeight(), + inGradValue->getWidth(), + false, + false); + MatrixPtr outGradCpuBuffer; + Matrix::resizeOrCreate(outGradCpuBuffer, + outGradValue->getHeight(), + outGradValue->getWidth(), + false, + false); + MatrixPtr roiCpuBuffer; + Matrix::resizeOrCreate(roiCpuBuffer, + roiValue->getHeight(), + roiValue->getWidth(), + false, + false); + inGradCpuBuffer->copyFrom(*inGradValue); + outGradCpuBuffer->copyFrom(*outGradValue); + roiCpuBuffer->copyFrom(*roiValue); + inGradValue = inGradCpuBuffer; + outGradValue = outGradCpuBuffer; + roiValue = roiCpuBuffer; + } + + real* bottomROIs = roiValue->getData(); + size_t numROIs = getInput(1).getBatchSize(); + size_t roiOffset = getInputValue(1)->getWidth(); + + real* inDiffData = inGradValue->getData(); + size_t batchOffset = getInputValue(0)->getWidth(); + size_t channelOffset = height_ * width_; + + real* outDiffData = outGradValue->getData(); + size_t poolChannelOffset = pooledHeight_ * pooledWidth_; + real* argmaxData = maxIdxs_->getData(); + + for (size_t n = 0; n < numROIs; ++n) { + size_t roiBatchIdx = bottomROIs[0]; + real* batchDiffData = inDiffData + batchOffset * roiBatchIdx; + for (size_t c = 0; c < channels_; ++c) { + for (size_t ph = 0; ph < pooledHeight_; ++ph) { + for (size_t pw = 0; pw < pooledWidth_; ++pw) { + size_t poolIndex = ph * pooledWidth_ + pw; + if (argmaxData[poolIndex] > 0) { + size_t index = static_cast(argmaxData[poolIndex]); + batchDiffData[index] += outDiffData[poolIndex]; + } + } + } + batchDiffData += channelOffset; + outDiffData += poolChannelOffset; + argmaxData += poolChannelOffset; + } + bottomROIs += roiOffset; + } + + if (useGpu_) { + getInputGrad(0)->copyFrom(*inGradValue); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/ROIPoolLayer.h b/paddle/gserver/layers/ROIPoolLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..4f07e49d6fd1eda9fa7bd46e4cec771a75f571be --- /dev/null +++ b/paddle/gserver/layers/ROIPoolLayer.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * A layer used by Fast R-CNN to extract feature maps of ROIs from the last + * feature map. + * - Input: This layer needs two input layers: The first input layer is a + * convolution layer; The second input layer contains the ROI data + * which is the output of ProposalLayer in Faster R-CNN. layers for + * generating bbox location offset and the classification confidence. + * - Output: The ROIs' feature map. + * Reference: + * Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. + * Faster R-CNN: Towards Real-Time Object Detection with Region Proposal + * Networks + */ + +class ROIPoolLayer : public Layer { +protected: + size_t channels_; + size_t width_; + size_t height_; + size_t pooledWidth_; + size_t pooledHeight_; + real spatialScale_; + + // Since there is no int matrix, use real maxtrix instead. + MatrixPtr maxIdxs_; + +public: + explicit ROIPoolLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index b129f6ec560f8f2ce49853341ffbd6c224c42912..3afaab630b21ecdf5aa7c4eb4dde922bd1962579 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2058,6 +2058,43 @@ TEST(Layer, CropLayer) { } } +TEST(Layer, roi_pool) { + TestConfig config; + config.layerConfig.set_type("roi_pool"); + config.biasSize = 0; + LayerInputConfig* input = config.layerConfig.add_inputs(); + ROIPoolConfig* roiPoolConf = input->mutable_roi_pool_conf(); + roiPoolConf->set_pooled_width(7); + roiPoolConf->set_pooled_height(7); + roiPoolConf->set_spatial_scale(1. / 16); + roiPoolConf->set_width(14); + roiPoolConf->set_height(14); + + const size_t roiNum = 10; + const size_t roiDim = 10; + const size_t batchSize = 5; + MatrixPtr roiValue = Matrix::create(roiNum, roiDim, false, false); + roiValue->zeroMem(); + real* roiData = roiValue->getData(); + for (size_t i = 0; i < roiNum; ++i) { + roiData[i * roiDim + 0] = std::rand() % batchSize; + roiData[i * roiDim + 1] = std::rand() % 224; // xMin + roiData[i * roiDim + 2] = std::rand() % 224; // yMin + size_t xMin = static_cast(roiData[i * roiDim + 1]); + size_t yMin = static_cast(roiData[i * roiDim + 2]); + roiData[i * roiDim + 3] = xMin + std::rand() % (224 - xMin); // xMax + roiData[i * roiDim + 4] = yMin + std::rand() % (224 - yMin); // yMax + } + + config.inputDefs.push_back({INPUT_DATA, "input", 3 * 14 * 14, {}}); + config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, "rois", roiValue, {}}); + config.layerConfig.add_inputs(); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "roi_pool", batchSize, false, useGpu, false); + } +} + TEST(Layer, SwitchOrderLayer) { TestConfig config; // config input_0 diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 09c3f0b1e6f787547b9253d3aeadf70674708ba0..1b0d4c8bdc683b5203a4bc4b3838560cffe00bc8 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -234,8 +234,8 @@ void gemv(const platform::DeviceContext& context, template struct SetConstant; -struct TensorSetConstant { - TensorSetConstant(framework::Tensor* tensor, float value) +struct TensorSetConstantCPU { + TensorSetConstantCPU(framework::Tensor* tensor, float value) : tensor_(tensor), value_(value) {} template void operator()() const { @@ -252,7 +252,7 @@ void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, float value) { framework::VisitDataType(framework::ToDataType(tensor->type()), - TensorSetConstant(tensor, value)); + TensorSetConstantCPU(tensor, value)); } struct TensorSetConstantWithPlace : public boost::static_visitor { diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 255e480680499877ff599b96b8336a968cccbb34..817deec94314bdfd2ed7e4b0ba5212c72b813455 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -233,8 +233,8 @@ void gemv(const platform::DeviceContext& context, template struct SetConstant; -struct TensorSetConstant { - TensorSetConstant(const platform::DeviceContext& context, +struct TensorSetConstantGPU { + TensorSetConstantGPU(const platform::DeviceContext& context, framework::Tensor* tensor, float value) : context_(context), tensor_(tensor), value_(value) {} @@ -254,7 +254,7 @@ void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, float value) { framework::VisitDataType(framework::ToDataType(tensor->type()), - TensorSetConstant(context, tensor, value)); + TensorSetConstantGPU(context, tensor, value)); } } // namespace math diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h index 5858cd4839d367bb888b2b98cde2225751391162..48e322f99398a7f1d6af9cab653d0cc92d981fe0 100644 --- a/paddle/operators/nccl/nccl_gpu_common.h +++ b/paddle/operators/nccl/nccl_gpu_common.h @@ -35,6 +35,7 @@ constexpr int kInvalidGPUId = -1; struct Communicator { std::vector comms_; std::unordered_map comm_id_map_; + bool inited_; Communicator() {} @@ -42,17 +43,21 @@ struct Communicator { void InitAll(const std::vector& gpus) { comms_.resize(gpus.size()); + inited_ = false; for (size_t i = 0; i < gpus.size(); ++i) { comm_id_map_[gpus[i]] = i; } PADDLE_ENFORCE( dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data())); + inited_ = true; } ~Communicator() { - for (size_t i = 0; i < comms_.size(); ++i) { - // FIXME(dzh) : PADDLE_ENFORCE return void - dynload::ncclCommDestroy(comms_[i]); + if (inited_) { + for (size_t i = 0; i < comms_.size(); ++i) { + // FIXME(dzh) : PADDLE_ENFORCE return void + dynload::ncclCommDestroy(comms_[i]); + } } } diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index b0e87b7059eab3772c179fe31cdb09477b589ed1..0075ccd24271bf83f139e121efad00c2316cc11b 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -387,8 +387,8 @@ class RecurrentGradOp : public RecurrentBase { auto &p_names = Inputs(kParameters); PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size()); - for (size_t prog_id = 0; prog_id < pg_names.size(); ++prog_id) { - auto inside_grad_name = framework::GradVarName(p_names[prog_id]); + for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) { + auto inside_grad_name = framework::GradVarName(p_names[param_id]); // If does not compute gradient of that variable inside rnn, just // continue @@ -406,27 +406,19 @@ class RecurrentGradOp : public RecurrentBase { attrs["value"] = 0.0f; auto zero_op = framework::OpRegistry::CreateOp( - "fill_constant", {}, {{"Out", {pg_names[prog_id]}}}, attrs); + "fill_constant", {}, {{"Out", {pg_names[param_id]}}}, attrs); zero_op->Run(scope, dev_ctx); } + auto new_inside_name = cur_scope.Rename(inside_grad_name); // sum gradient - auto *outside_var = scope.FindVar(pg_names[prog_id]); - PADDLE_ENFORCE(outside_var != nullptr); - auto &outside_tensor = - *outside_var->GetMutable(); - - std::string result_var_name; - auto *local_result_var = cur_scope.Var(&result_var_name); - auto &local_result_tensor = - *local_result_var->GetMutable(); - - local_result_tensor.ShareDataWith(outside_tensor); auto sum_op = framework::OpRegistry::CreateOp( - "sum", {{"X", {result_var_name, inside_grad_name}}}, - {{"Out", {result_var_name}}}, {}); + "sum", {{"X", {pg_names[param_id], new_inside_name}}}, + {{"Out", {pg_names[param_id]}}}, {}); sum_op->Run(cur_scope, dev_ctx); + + cur_scope.Rename(new_inside_name, inside_grad_name); } } VLOG(5) << "Accumulate Parameter finished "; diff --git a/paddle/platform/call_once.h b/paddle/platform/call_once.h new file mode 100644 index 0000000000000000000000000000000000000000..248baf6613c5caebdabcecff5d57290585238d78 --- /dev/null +++ b/paddle/platform/call_once.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include + +namespace paddle { +namespace platform { + +/* + The current implementation of std::call_once has a bug described in + https://stackoverflow.com/questions/41717579/stdcall-once-hangs-on-second-call-after-callable-threw-on-first-call. + This is likely caused by a deeper bug of pthread_once, which is discussed in + https://patchwork.ozlabs.org/patch/482350/ + + This wrap is a hack to avoid this bug. +*/ +template +inline void call_once(std::once_flag& flag, Callable&& f, Args&&... args) { + bool good = false; + std::exception ex; + std::call_once(flag, [&]() { + try { + f(args...); + good = true; + } catch (const std::exception& e) { + ex = e; + } catch (...) { + ex = std::runtime_error("excption caught in call_once"); + } + }); + if (!good) { + throw std::exception(ex); + } +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/dynload/nccl.h b/paddle/platform/dynload/nccl.h index 0618c7414fd1235e81ee9d92a3a07b53d6ad6ebc..981b2ab258a34ce92f02ee12b5957f88ba61d1c0 100644 --- a/paddle/platform/dynload/nccl.h +++ b/paddle/platform/dynload/nccl.h @@ -17,6 +17,7 @@ #include #include #include +#include "paddle/platform/call_once.h" #include "paddle/platform/dynload/dynamic_loader.h" namespace paddle { @@ -27,18 +28,18 @@ extern std::once_flag nccl_dso_flag; extern void* nccl_dso_handle; #ifdef PADDLE_USE_DSO -#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> decltype(__name(args...)) { \ - using nccl_func = decltype(__name(args...)) (*)(Args...); \ - std::call_once(nccl_dso_flag, \ - paddle::platform::dynload::GetNCCLDsoHandle, \ - &nccl_dso_handle); \ - void* p_##__name = dlsym(nccl_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ +#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + using nccl_func = decltype(__name(args...)) (*)(Args...); \ + platform::call_once(nccl_dso_flag, \ + paddle::platform::dynload::GetNCCLDsoHandle, \ + &nccl_dso_handle); \ + void* p_##__name = dlsym(nccl_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ extern DynLoad__##__name __name #else #define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \ diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 2d7ff1df98a9a448b447890537f20dd416a9ae9d..2c2cc6245932d4af56a68d6399ce31f008bf3748 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -321,6 +321,14 @@ message ClipConfig { required double max = 2; } +message ROIPoolConfig { + required uint32 pooled_width = 1; + required uint32 pooled_height = 2; + required float spatial_scale = 3; + optional uint32 height = 4 [ default = 1 ]; + optional uint32 width = 5 [ default = 1 ]; +} + message ScaleSubRegionConfig { required ImageConfig image_conf = 1; required float value = 2; @@ -348,6 +356,7 @@ message LayerInputConfig { optional DetectionOutputConfig detection_output_conf = 17; optional ClipConfig clip_conf = 18; optional ScaleSubRegionConfig scale_sub_region_conf = 19; + optional ROIPoolConfig roi_pool_conf = 20; } message LayerConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index fd7a3e7e8f26bb3ee4e4d2be949b3538a8a915e2..5891b1f0cd657ccf0d4577ec85ec51336c360534 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1969,6 +1969,18 @@ class DetectionOutputLayer(LayerBase): self.config.size = size +@config_layer('roi_pool') +class ROIPoolLayer(LayerBase): + def __init__(self, name, inputs, pooled_width, pooled_height, spatial_scale, + num_channels, **xargs): + super(ROIPoolLayer, self).__init__(name, 'roi_pool', 0, inputs) + config_assert(len(inputs) == 2, 'ROIPoolLayer must have 2 inputs') + self.config.inputs[0].roi_pool_conf.pooled_width = pooled_width + self.config.inputs[0].roi_pool_conf.pooled_height = pooled_height + self.config.inputs[0].roi_pool_conf.spatial_scale = spatial_scale + self.set_cnn_layer(name, pooled_height, pooled_width, num_channels) + + @config_layer('data') class DataLayer(LayerBase): def __init__(self, diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 6da5662d436ef6aa7b647450e0708db5ebaca950..f7ab7a5ca0a369e976c783fdc793abd69a683c3c 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -122,6 +122,7 @@ __all__ = [ 'cross_channel_norm_layer', 'multibox_loss_layer', 'detection_output_layer', + 'roi_pool_layer', 'spp_layer', 'pad_layer', 'eos_layer', @@ -221,6 +222,7 @@ class LayerType(object): PRIORBOX_LAYER = 'priorbox' MULTIBOX_LOSS_LAYER = 'multibox_loss' DETECTION_OUTPUT_LAYER = 'detection_output' + ROI_POOL_LAYER = 'roi_pool' CTC_LAYER = 'ctc' WARP_CTC_LAYER = 'warp_ctc' @@ -1305,6 +1307,50 @@ def detection_output_layer(input_loc, name, LayerType.DETECTION_OUTPUT_LAYER, parents=parents, size=size) +@wrap_name_default("roi_pool") +def roi_pool_layer(input, + rois, + pooled_width, + pooled_height, + spatial_scale, + num_channels=None, + name=None): + """ + A layer used by Fast R-CNN to extract feature maps of ROIs from the last + feature map. + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. + :type input: LayerOutput. + :param rois: The input ROIs' data. + :type rois: LayerOutput. + :param pooled_width: The width after pooling. + :type pooled_width: int + :param pooled_height: The height after pooling. + :type pooled_height: int + :param spatial_scale: The spatial scale between the image and feature map. + :type spatial_scale: float + :param num_channels: number of input channel. + :type num_channels: int + :return: LayerOutput + """ + if num_channels is None: + assert input.num_filters is not None + num_channels = input.num_filters + size = num_channels * pooled_width * pooled_height + Layer( + name=name, + type=LayerType.ROI_POOL_LAYER, + inputs=[input.name, rois.name], + pooled_width=pooled_width, + pooled_height=pooled_height, + spatial_scale=spatial_scale, + num_channels=num_channels) + return LayerOutput( + name, LayerType.ROI_POOL_LAYER, parents=[input, rois], size=size) + + @wrap_name_default("cross_channel_norm") def cross_channel_norm_layer(input, name=None, param_attr=None): """ diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 42aaed7a6469342086b8273eb5b80eaea905f851..1c7451e0abf5dc1b99671f292e2ffc2d2282abe9 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -9,7 +9,7 @@ test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer -test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer +test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..f1bc65b3aee7488700a9d24e049adb510649c475 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr @@ -0,0 +1,98 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 588 + active_type: "" + height: 14 + width: 14 +} +layers { + name: "rois" + type: "data" + size: 10 + active_type: "" +} +layers { + name: "__conv_0__" + type: "exconv" + size: 3136 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___conv_0__.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 1 + padding: 1 + groups: 1 + filter_channels: 3 + output_x: 14 + img_size: 14 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 1 + output_y: 14 + img_size_y: 14 + } + } + bias_parameter_name: "___conv_0__.wbias" + num_filters: 16 + shared_biases: true + height: 14 + width: 14 +} +layers { + name: "__roi_pool_0__" + type: "roi_pool" + size: 784 + active_type: "" + inputs { + input_layer_name: "__conv_0__" + roi_pool_conf { + pooled_width: 7 + pooled_height: 7 + spatial_scale: 0.0625 + } + } + inputs { + input_layer_name: "rois" + } + height: 7 + width: 7 +} +parameters { + name: "___conv_0__.w0" + size: 432 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "___conv_0__.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +input_layer_names: "rois" +output_layer_names: "__roi_pool_0__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "rois" + layer_names: "__conv_0__" + layer_names: "__roi_pool_0__" + input_layer_names: "data" + input_layer_names: "rois" + output_layer_names: "__roi_pool_0__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_roi_pool_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_roi_pool_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..b739a81b8505c94a2312ac735647fb114982f1f7 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_roi_pool_layer.py @@ -0,0 +1,23 @@ +from paddle.trainer_config_helpers import * + +data = data_layer(name='data', size=3 * 14 * 14, height=14, width=14) + +rois = data_layer(name='rois', size=10) + +conv = img_conv_layer( + input=data, + filter_size=3, + num_channels=3, + num_filters=16, + padding=1, + act=LinearActivation(), + bias_attr=True) + +roi_pool = roi_pool_layer( + input=conv, + rois=rois, + pooled_width=7, + pooled_height=7, + spatial_scale=1. / 16) + +outputs(roi_pool)