nets.py 3.7 KB
Newer Older
1
import layers
F
fengjiayi 已提交
2

D
dzhwinter 已提交
3 4
# Names exported by a wildcard import of this module.
# NOTE(review): img_conv_group is defined in this module but is not listed
# here -- confirm whether that omission is intentional.
__all__ = ["simple_img_conv_pool", "sequence_conv_pool"]

F
fengjiayi 已提交
5 6 7

def simple_img_conv_pool(input,
                         num_filters,
                         filter_size,
                         pool_size,
                         pool_stride,
                         act,
                         pool_type='max',
                         main_program=None,
                         startup_program=None):
    """
    Apply one 2-D convolution to ``input`` and pool its output.

    The convolution is created with ``layers.conv2d`` from ``num_filters``,
    ``filter_size`` and ``act``. Its output is handed to ``layers.pool2d``
    together with ``pool_size``, ``pool_type`` and ``pool_stride``.
    ``main_program`` and ``startup_program`` are forwarded unchanged to both
    layer calls.

    Returns:
        Whatever ``layers.pool2d`` returns for the convolved input.
    """
    # Both layer calls receive the same pair of program arguments.
    programs = {
        'main_program': main_program,
        'startup_program': startup_program,
    }

    convolved = layers.conv2d(
        input=input,
        num_filters=num_filters,
        filter_size=filter_size,
        act=act,
        **programs)

    return layers.pool2d(
        input=convolved,
        pool_size=pool_size,
        pool_type=pool_type,
        pool_stride=pool_stride,
        **programs)


def img_conv_group(input,
                   conv_num_filter,
                   pool_size,
                   conv_padding=1,
                   conv_filter_size=3,
                   conv_act=None,
                   conv_with_batchnorm=False,
                   conv_batchnorm_drop_rate=None,
                   pool_stride=1,
                   pool_type=None,
                   main_program=None,
                   startup_program=None):
    """
    Image Convolution Group, Used for vgg net.

    Chains one ``layers.conv2d`` per entry of ``conv_num_filter``. Each
    convolution may be followed by ``layers.batch_norm`` and, after that,
    ``layers.dropout``. The output of the chain goes through a single
    ``layers.pool2d``.

    Args:
        input: Input handed to the first convolution.
        conv_num_filter (list|tuple): ``num_filters`` of each convolution;
            its length determines how many convolutions are created.
        pool_size: ``pool_size`` of the final pooling layer.
        conv_padding: ``padding`` of the convolutions. Either a single value
            used for every convolution, or a sequence indexed per
            convolution.
        conv_filter_size: ``filter_size`` of the convolutions; single value
            or per-convolution sequence, like ``conv_padding``.
        conv_act: Activation. Passed to the convolution when batch norm is
            off for that convolution, otherwise passed to the batch norm
            layer instead.
        conv_with_batchnorm: Whether a convolution is followed by batch
            norm; single value or per-convolution sequence.
        conv_batchnorm_drop_rate: Dropout probability applied after batch
            norm; single value or per-convolution sequence. ``None`` or a
            magnitude of at most 1e-5 means no dropout layer is added. Only
            consulted for convolutions that use batch norm.
        pool_stride: ``pool_stride`` of the final pooling layer.
        pool_type: ``pool_type`` of the final pooling layer.
        main_program: Forwarded unchanged to every layer call.
        startup_program: Forwarded unchanged to every layer call.

    Returns:
        Whatever ``layers.pool2d`` returns for the output of the chain.

    Raises:
        AssertionError: If ``conv_num_filter`` is neither a list nor a tuple.
    """
    assert isinstance(conv_num_filter, (list, tuple))
    num_convs = len(conv_num_filter)

    def _extend_list(obj):
        # A value without a length is one setting shared by all convolutions.
        if hasattr(obj, '__len__'):
            return obj
        return [obj] * num_convs

    conv_padding = _extend_list(conv_padding)
    conv_filter_size = _extend_list(conv_filter_size)
    conv_with_batchnorm = _extend_list(conv_with_batchnorm)
    conv_batchnorm_drop_rate = _extend_list(conv_batchnorm_drop_rate)

    # Every layer call receives the same pair of program arguments.
    programs = {
        'main_program': main_program,
        'startup_program': startup_program,
    }

    tmp = input
    # enumerate() instead of xrange() keeps the loop valid on Python 2 and 3.
    for i, num_filters in enumerate(conv_num_filter):
        with_batchnorm = conv_with_batchnorm[i]

        # With batch norm the activation moves from the convolution to the
        # batch norm layer, so the convolution itself gets no activation.
        tmp = layers.conv2d(
            input=tmp,
            num_filters=num_filters,
            filter_size=conv_filter_size[i],
            padding=conv_padding[i],
            act=None if with_batchnorm else conv_act,
            **programs)

        if not with_batchnorm:
            continue

        tmp = layers.batch_norm(input=tmp, act=conv_act, **programs)

        drop_rate = conv_batchnorm_drop_rate[i]
        # The default drop rate is None; abs(None) would raise TypeError, so
        # None is treated the same as a zero rate (no dropout layer).
        if drop_rate is not None and abs(drop_rate) > 1e-5:
            tmp = layers.dropout(x=tmp, dropout_prob=drop_rate, **programs)

    return layers.pool2d(
        input=tmp,
        pool_size=pool_size,
        pool_type=pool_type,
        pool_stride=pool_stride,
        **programs)
D
dzhwinter 已提交
99 100 101 102 103


def sequence_conv_pool(input,
                       num_filters,
                       filter_size,
                       act="sigmoid",
                       pool_type="max",
                       main_program=None,
                       startup_program=None):
    """
    Apply a sequence convolution to ``input`` and pool its output.

    The convolution is created with ``layers.sequence_conv`` from
    ``num_filters``, ``filter_size`` and ``act``. Its output is handed to
    ``layers.sequence_pool`` together with ``pool_type``. ``main_program``
    and ``startup_program`` are forwarded unchanged to both layer calls.

    Returns:
        Whatever ``layers.sequence_pool`` returns for the convolved input.
    """
    # Both layer calls receive the same pair of program arguments.
    programs = {
        'main_program': main_program,
        'startup_program': startup_program,
    }

    convolved = layers.sequence_conv(
        input=input,
        num_filters=num_filters,
        filter_size=filter_size,
        act=act,
        **programs)

    return layers.sequence_pool(
        input=convolved, pool_type=pool_type, **programs)