提交 7829034d 编写于 作者: G guosheng

Refine ROIPoolLayer by following comments

上级 1ffdecf1
...@@ -91,6 +91,8 @@ void ROIPoolLayer::forward(PassType passType) { ...@@ -91,6 +91,8 @@ void ROIPoolLayer::forward(PassType passType) {
real* argmaxData = maxIdxs_->getData(); real* argmaxData = maxIdxs_->getData();
for (size_t n = 0; n < numROIs; ++n) { for (size_t n = 0; n < numROIs; ++n) {
// the first five elements of each RoI should be:
// batch_idx, roi_x_start, roi_y_start, roi_x_end, roi_y_end
size_t roiBatchIdx = bottomROIs[0]; size_t roiBatchIdx = bottomROIs[0];
size_t roiStartW = round(bottomROIs[1] * spatialScale_); size_t roiStartW = round(bottomROIs[1] * spatialScale_);
size_t roiStartH = round(bottomROIs[2] * spatialScale_); size_t roiStartH = round(bottomROIs[2] * spatialScale_);
......
...@@ -41,6 +41,7 @@ protected: ...@@ -41,6 +41,7 @@ protected:
size_t pooledHeight_; size_t pooledHeight_;
real spatialScale_; real spatialScale_;
// Since there is no int matrix, use real matrix instead.
MatrixPtr maxIdxs_; MatrixPtr maxIdxs_;
public: public:
......
...@@ -1971,13 +1971,14 @@ class DetectionOutputLayer(LayerBase): ...@@ -1971,13 +1971,14 @@ class DetectionOutputLayer(LayerBase):
@config_layer('roi_pool') @config_layer('roi_pool')
class ROIPoolLayer(LayerBase): class ROIPoolLayer(LayerBase):
def __init__(self, name, inputs, pooled_width, pooled_height, def __init__(self, name, inputs, pooled_width, pooled_height, spatial_scale,
spatial_scale): num_channels, **xargs):
super(ROIPoolLayer, self).__init__(name, 'roi_pool', 0, inputs) super(ROIPoolLayer, self).__init__(name, 'roi_pool', 0, inputs)
config_assert(len(inputs) == 2, 'ROIPoolLayer must have 2 inputs') config_assert(len(inputs) == 2, 'ROIPoolLayer must have 2 inputs')
self.config.inputs[0].roi_pool_conf.pooled_width = pooled_width self.config.inputs[0].roi_pool_conf.pooled_width = pooled_width
self.config.inputs[0].roi_pool_conf.pooled_height = pooled_height self.config.inputs[0].roi_pool_conf.pooled_height = pooled_height
self.config.inputs[0].roi_pool_conf.spatial_scale = spatial_scale self.config.inputs[0].roi_pool_conf.spatial_scale = spatial_scale
self.set_cnn_layer(name, pooled_height, pooled_width, num_channels)
@config_layer('data') @config_layer('data')
......
...@@ -1345,7 +1345,8 @@ def roi_pool_layer(input, ...@@ -1345,7 +1345,8 @@ def roi_pool_layer(input,
inputs=[input.name, rois.name], inputs=[input.name, rois.name],
pooled_width=pooled_width, pooled_width=pooled_width,
pooled_height=pooled_height, pooled_height=pooled_height,
spatial_scale=spatial_scale) spatial_scale=spatial_scale,
num_channels=num_channels)
return LayerOutput( return LayerOutput(
name, LayerType.ROI_POOL_LAYER, parents=[input, rois], size=size) name, LayerType.ROI_POOL_LAYER, parents=[input, rois], size=size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册