# roipooling.py
from .register import register
from x2paddle.core.util import *


def roipooling_shape(input_shape, pooled_w=None, pooled_h=None):
    """Compute the output shape of the ROIPooling layer.

    Args:
        input_shape: Sequence holding two shapes. The first is the shape of
            the base feature map, the second is the shape of the ROI input.
        pooled_w: Value stored at index 3 of the output shape.
        pooled_h: Value stored at index 2 of the output shape.

    Returns:
        A one-element list containing the output shape: the feature-map
        shape with index 0 replaced by the first ROI dimension and indices
        2 and 3 replaced by ``pooled_h`` and ``pooled_w``.
    """
    base_fea_shape = input_shape[0]
    rois_shape = input_shape[1]
    # Work on a copy so the caller's feature-map shape is not modified.
    output_shape = list(base_fea_shape)
    output_shape[0] = rois_shape[0]
    output_shape[2] = pooled_h
    output_shape[3] = pooled_w
    return [output_shape]


def roipooling_layer(inputs,
                     pooled_w=None,
                     pooled_h=None,
                     spatial_scale=None,
                     input_shape=None,
                     name=None):
    """Build the ROI pooling op for this layer.

    Args:
        inputs: Sequence of two inputs; the first is the feature map and
            the second is the ROI tensor.
        pooled_w: Passed to ``roi_pool`` as ``pooled_width``.
        pooled_h: Passed to ``roi_pool`` as ``pooled_height``.
        spatial_scale: Passed through to ``roi_pool``.
        input_shape: Unused here; kept for interface compatibility.
        name: Unused here; kept for interface compatibility.

    Returns:
        The result of ``fluid.layers.roi_pool``.
    """
    feature_map = inputs[0]
    # Keep columns 1..4 along axis 1 of the ROI input and drop column 0
    # (presumably the batch index of each ROI -- TODO confirm).
    rois = paddle.slice(inputs[1], axes=[1], starts=[1], ends=[5])
    return fluid.layers.roi_pool(
        feature_map,
        rois,
        pooled_height=pooled_h,
        pooled_width=pooled_w,
        spatial_scale=spatial_scale)


def roipooling_weights(name, data=None):
    """Return the weight names of this layer.

    Both arguments are ignored and a new empty list is returned on
    every call.
    """
    return []


# Layer registration.
# Register the shape, layer and weights handlers under the 'ROIPooling' kind.
register(kind='ROIPooling', shape=roipooling_shape,
         layer=roipooling_layer, weights=roipooling_weights)