提交 7829034d 编写于 作者: G guosheng

Refine ROIPoolLayer by following comments

上级 1ffdecf1
......@@ -91,6 +91,8 @@ void ROIPoolLayer::forward(PassType passType) {
real* argmaxData = maxIdxs_->getData();
for (size_t n = 0; n < numROIs; ++n) {
// the first five elememts of each RoI should be:
// batch_idx, roi_x_start, roi_y_start, roi_x_end, roi_y_end
size_t roiBatchIdx = bottomROIs[0];
size_t roiStartW = round(bottomROIs[1] * spatialScale_);
size_t roiStartH = round(bottomROIs[2] * spatialScale_);
......
......@@ -41,6 +41,7 @@ protected:
size_t pooledHeight_;
real spatialScale_;
// Since there is no int matrix, use real maxtrix instead.
MatrixPtr maxIdxs_;
public:
......
......@@ -1971,13 +1971,14 @@ class DetectionOutputLayer(LayerBase):
@config_layer('roi_pool')
class ROIPoolLayer(LayerBase):
def __init__(self, name, inputs, pooled_width, pooled_height,
spatial_scale):
def __init__(self, name, inputs, pooled_width, pooled_height, spatial_scale,
num_channels, **xargs):
super(ROIPoolLayer, self).__init__(name, 'roi_pool', 0, inputs)
config_assert(len(inputs) == 2, 'ROIPoolLayer must have 2 inputs')
self.config.inputs[0].roi_pool_conf.pooled_width = pooled_width
self.config.inputs[0].roi_pool_conf.pooled_height = pooled_height
self.config.inputs[0].roi_pool_conf.spatial_scale = spatial_scale
self.set_cnn_layer(name, pooled_height, pooled_width, num_channels)
@config_layer('data')
......
......@@ -1345,7 +1345,8 @@ def roi_pool_layer(input,
inputs=[input.name, rois.name],
pooled_width=pooled_width,
pooled_height=pooled_height,
spatial_scale=spatial_scale)
spatial_scale=spatial_scale,
num_channels=num_channels)
return LayerOutput(
name, LayerType.ROI_POOL_LAYER, parents=[input, rois], size=size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册