# nets.py: composite network helpers assembled from the `layers` module.
import layers
# Names exported by `from nets import *`. Every public helper defined in
# this module is listed, including `img_conv_group`.
__all__ = [
    "simple_img_conv_pool",
    "img_conv_group",
    "sequence_conv_pool",
]


def simple_img_conv_pool(input,
                         num_filters,
                         filter_size,
                         pool_size,
                         pool_stride,
                         act,
                         param_attr=None,
                         pool_type='max',
                         use_cudnn=True):
    """
    Apply one `layers.conv2d` followed by one `layers.pool2d`.

    Args:
        input: Passed to `layers.conv2d` as its `input`.
        num_filters: Passed to `layers.conv2d` as `num_filters`.
        filter_size: Passed to `layers.conv2d` as `filter_size`.
        pool_size: Passed to `layers.pool2d` as `pool_size`.
        pool_stride: Passed to `layers.pool2d` as `pool_stride`.
        act: Activation forwarded to `layers.conv2d` as `act`.
        param_attr: Forwarded to `layers.conv2d` as `param_attr`.
        pool_type: Forwarded to `layers.pool2d` as `pool_type`
            (default 'max').
        use_cudnn: Forwarded unchanged to both the convolution and the
            pooling layer.

    Returns:
        The value returned by `layers.pool2d`.
    """
    conv_out = layers.conv2d(
        input=input,
        num_filters=num_filters,
        filter_size=filter_size,
        param_attr=param_attr,
        act=act,
        use_cudnn=use_cudnn)

    pool_out = layers.pool2d(
        input=conv_out,
        pool_size=pool_size,
        pool_type=pool_type,
        pool_stride=pool_stride,
        use_cudnn=use_cudnn)
    return pool_out


def img_conv_group(input,
                   conv_num_filter,
                   pool_size,
                   conv_padding=1,
                   conv_filter_size=3,
                   conv_act=None,
                   param_attr=None,
                   conv_with_batchnorm=False,
                   conv_batchnorm_drop_rate=None,
                   conv_use_cudnn=True,
                   pool_stride=1,
                   pool_type=None,
                   pool_use_cudnn=True):
    """
    Image Convolution Group, Used for vgg net.

    Stacks one `layers.conv2d` per entry of `conv_num_filter`, each
    optionally followed by `layers.batch_norm` and `layers.dropout`, and
    finishes with a single `layers.pool2d`.

    Args:
        input: Input to the first convolution.
        conv_num_filter: list or tuple; entry i is the `num_filters` of
            the i-th convolution. Its length sets the number of
            convolutions.
        pool_size: Passed to the final `layers.pool2d` as `pool_size`.
        conv_padding: Per-convolution `padding`. A value without
            `__len__` is repeated for every convolution; anything with
            `__len__` is indexed per convolution as given.
        conv_filter_size: Per-convolution `filter_size`; same
            broadcasting rule as `conv_padding`.
        conv_act: Activation of each convolution. When batch norm is
            enabled for a convolution, the activation is applied by the
            batch norm layer instead of the convolution.
        param_attr: Per-convolution `param_attr`; same broadcasting rule
            as `conv_padding`.
        conv_with_batchnorm: Per-convolution flag enabling batch norm;
            same broadcasting rule as `conv_padding`.
        conv_batchnorm_drop_rate: Per-convolution dropout probability,
            used only after batch norm; same broadcasting rule as
            `conv_padding`. `None` or a magnitude of at most 1e-5 means
            no dropout.
        conv_use_cudnn: Forwarded to every `layers.conv2d` call.
        pool_stride: Passed to `layers.pool2d` as `pool_stride`.
        pool_type: Passed to `layers.pool2d` as `pool_type`.
            NOTE(review): the default is None and is forwarded
            unchanged -- confirm `layers.pool2d` accepts it.
        pool_use_cudnn: Passed to `layers.pool2d` as `use_cudnn`.

    Returns:
        The value returned by the final `layers.pool2d`.
    """
    assert isinstance(conv_num_filter, (list, tuple))
    num_convs = len(conv_num_filter)

    def _extend_list(obj):
        # Values with a length are taken as already per-convolution;
        # everything else (including None) is repeated for each one.
        if hasattr(obj, '__len__'):
            return obj
        return [obj] * num_convs

    conv_padding = _extend_list(conv_padding)
    conv_filter_size = _extend_list(conv_filter_size)
    param_attr = _extend_list(param_attr)
    conv_with_batchnorm = _extend_list(conv_with_batchnorm)
    conv_batchnorm_drop_rate = _extend_list(conv_batchnorm_drop_rate)

    tmp = input
    # enumerate() replaces xrange(), which does not exist on Python 3.
    for i, num_filter in enumerate(conv_num_filter):
        use_batchnorm = conv_with_batchnorm[i]

        tmp = layers.conv2d(
            input=tmp,
            num_filters=num_filter,
            filter_size=conv_filter_size[i],
            padding=conv_padding[i],
            param_attr=param_attr[i],
            # With batch norm the activation moves after normalization.
            act=None if use_batchnorm else conv_act,
            use_cudnn=conv_use_cudnn)

        if use_batchnorm:
            tmp = layers.batch_norm(input=tmp, act=conv_act)
            drop_rate = conv_batchnorm_drop_rate[i]
            # The default drop rate is None; abs(None) raises TypeError,
            # so None is treated as "no dropout".
            if drop_rate is not None and abs(drop_rate) > 1e-5:
                tmp = layers.dropout(x=tmp, dropout_prob=drop_rate)

    pool_out = layers.pool2d(
        input=tmp,
        pool_size=pool_size,
        pool_type=pool_type,
        pool_stride=pool_stride,
        use_cudnn=pool_use_cudnn)
    return pool_out


def sequence_conv_pool(input,
                       num_filters,
                       filter_size,
                       param_attr=None,
                       act="sigmoid",
                       pool_type="max"):
    """
    Apply `layers.sequence_conv` followed by `layers.sequence_pool`.

    Args:
        input: Passed to `layers.sequence_conv` as its `input`.
        num_filters: Passed to `layers.sequence_conv` as `num_filters`.
        filter_size: Passed to `layers.sequence_conv` as `filter_size`.
        param_attr: Forwarded to `layers.sequence_conv` as `param_attr`.
        act: Activation forwarded to `layers.sequence_conv`
            (default "sigmoid").
        pool_type: Forwarded to `layers.sequence_pool` as `pool_type`
            (default "max").

    Returns:
        The value returned by `layers.sequence_pool`.
    """
    conv_out = layers.sequence_conv(
        input=input,
        num_filters=num_filters,
        filter_size=filter_size,
        param_attr=param_attr,
        act=act)

    pool_out = layers.sequence_pool(input=conv_out, pool_type=pool_type)
    return pool_out