import paddle.v2 as paddle

__all__ = ['resnet_imagenet', 'resnet_cifar10']


def conv_bn_layer(input,
                  ch_out,
                  filter_size,
                  stride,
                  padding,
                  active_type=paddle.activation.Relu(),
                  ch_in=None):
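    """Conv2D (no bias) followed by batch norm; the activation is applied
    after the batch-norm layer rather than after the convolution."""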
    tmp = paddle.layer.img_conv(
        input=input,
        filter_size=filter_size,
        num_channels=ch_in,
        num_filters=ch_out,
        stride=stride,
        padding=padding,
        act=paddle.activation.Linear(),
        bias_attr=False)
    return paddle.layer.batch_norm(input=tmp, act=active_type)


def shortcut(input, ch_in, ch_out, stride):
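    """Identity shortcut when the channel counts match; otherwise a 1x1
    projection conv (linear activation) to match ch_out and stride."""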
    if ch_in != ch_out:
        return conv_bn_layer(input, ch_out, 1, stride, 0,
                             paddle.activation.Linear())
    else:
        return input


def basicblock(input, ch_in, ch_out, stride):
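    """Standard ResNet basic block: two 3x3 conv-bn layers plus a
    shortcut, summed and passed through ReLU."""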
    short = shortcut(input, ch_in, ch_out, stride)
    conv1 = conv_bn_layer(input, ch_out, 3, stride, 1)
    conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, paddle.activation.Linear())
    return paddle.layer.addto(
        input=[short, conv2], act=paddle.activation.Relu())


def bottleneck(input, ch_in, ch_out, stride):
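    """ResNet bottleneck block: a 1x1 -> 3x3 -> 1x1 conv stack; the last
    conv expands to ch_out * 4 channels before the residual addition."""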
    short = shortcut(input, ch_in, ch_out * 4, stride)
    conv1 = conv_bn_layer(input, ch_out, 1, stride, 0)
    conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1)
    conv3 = conv_bn_layer(conv2, ch_out * 4, 1, 1, 0,
                          paddle.activation.Linear())
    return paddle.layer.addto(
        input=[short, conv3], act=paddle.activation.Relu())


def layer_warp(block_func, input, ch_in, ch_out, count, stride):
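    """Stack `count` blocks; only the first block uses `stride` (and may
    change the channel count), the remaining blocks use stride 1."""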
    conv = block_func(input, ch_in, ch_out, stride)
    for i in range(1, count):
        conv = block_func(conv, ch_out, ch_out, 1)
    return conv


def resnet_imagenet(input, class_dim, depth=50):
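    """ResNet for ImageNet-sized input; `depth` selects the stage
    configuration (18/34 use basic blocks, 50/101/152 use bottlenecks)."""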
    cfg = {
        18: ([2, 2, 2, 2], basicblock),
        34: ([3, 4, 6, 3], basicblock),
        50: ([3, 4, 6, 3], bottleneck),
        101: ([3, 4, 23, 3], bottleneck),
        152: ([3, 8, 36, 3], bottleneck)
    }
    stages, block_func = cfg[depth]
    conv1 = conv_bn_layer(
        input, ch_in=3, ch_out=64, filter_size=7, stride=2, padding=3)
    pool1 = paddle.layer.img_pool(input=conv1, pool_size=3, stride=2)
    res1 = layer_warp(block_func, pool1, 64, 64, stages[0], 1)
    res2 = layer_warp(block_func, res1, 64, 128, stages[1], 2)
    res3 = layer_warp(block_func, res2, 128, 256, stages[2], 2)
    res4 = layer_warp(block_func, res3, 256, 512, stages[3], 2)
    pool2 = paddle.layer.img_pool(
        input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg())
    out = paddle.layer.fc(
        input=pool2, size=class_dim, act=paddle.activation.Softmax())
    return out


def resnet_cifar10(input, class_dim, depth=32):
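    """ResNet for CIFAR-10 with 6n + 2 layers; the three stages use 16,
    32 and 64 channels respectively."""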
    # depth should be one of 20, 32, 44, 56, 110, 1202
    assert (depth - 2) % 6 == 0
    n = (depth - 2) // 6
    conv1 = conv_bn_layer(
        input, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1)
    res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
    res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
    res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
    pool = paddle.layer.img_pool(
        input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg())
    out = paddle.layer.fc(
        input=pool, size=class_dim, act=paddle.activation.Softmax())
    return out
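

# A minimal usage sketch, assuming the classic paddle.v2 data-layer and
# cost APIs; CIFAR-10 images are fed as a flattened 3*32*32 dense vector.
if __name__ == '__main__':
    paddle.init(use_gpu=False, trainer_count=1)
    image = paddle.layer.data(
        name='image', type=paddle.data_type.dense_vector(3 * 32 * 32))
    label = paddle.layer.data(
        name='label', type=paddle.data_type.integer_value(10))
    net = resnet_cifar10(image, class_dim=10, depth=32)
    cost = paddle.layer.classification_cost(input=net, label=label)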