resnet.py 3.3 KB
Newer Older
W
wwhu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
import paddle.v2 as paddle

__all__ = ['resnet_imagenet', 'resnet_cifar10']


def conv_bn_layer(input,
                  ch_out,
                  filter_size,
                  stride,
                  padding,
                  active_type=paddle.activation.Relu(),
                  ch_in=None):
    """Stack an image convolution and a batch-normalization layer.

    The convolution is created without a bias term (``bias_attr=False``)
    and with a linear activation; the requested activation is applied by
    the batch-normalization layer that follows it.

    Args:
        input: Layer feeding the convolution.
        ch_out: Number of convolution filters.
        filter_size: Convolution filter size.
        stride: Convolution stride.
        padding: Convolution padding.
        active_type: Activation handed to the batch-normalization layer.
        ch_in: Passed to the convolution as ``num_channels``.

    Returns:
        The batch-normalization layer built on top of the convolution.
    """
    conv = paddle.layer.img_conv(
        input=input,
        num_channels=ch_in,
        num_filters=ch_out,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        bias_attr=False,
        act=paddle.activation.Linear())
    normalized = paddle.layer.batch_norm(input=conv, act=active_type)
    return normalized


def shortcut(input, n_out, stride, b_projection):
    """Return the shortcut branch of a residual block.

    Args:
        input: Layer entering the residual block.
        n_out: Number of filters used by the projection.
        stride: Stride used by the projection.
        b_projection: When truthy, project ``input`` through a 1x1
            conv + batch-norm layer with a linear activation; otherwise
            the shortcut is the identity.

    Returns:
        The projected layer, or ``input`` itself when no projection is
        requested.
    """
    # Identity shortcut: nothing to build.
    if not b_projection:
        return input
    return conv_bn_layer(
        input,
        n_out,
        filter_size=1,
        stride=stride,
        padding=0,
        active_type=paddle.activation.Linear())


def basicblock(input, ch_out, stride, b_projection):
    """Build a two-convolution residual block.

    The main branch is a 3x3 conv-BN layer (ReLU, given ``stride``)
    followed by a 3x3 conv-BN layer (linear activation, stride 1). It is
    summed with the shortcut branch and a ReLU is applied to the sum.

    Args:
        input: Layer entering the block.
        ch_out: Number of filters of both convolutions.
        stride: Stride of the first convolution and of the shortcut.
        b_projection: Forwarded to ``shortcut`` to select a projection
            or identity shortcut.

    Returns:
        The ``addto`` layer combining both branches.
    """
    first = conv_bn_layer(
        input, ch_out, filter_size=3, stride=stride, padding=1)
    second = conv_bn_layer(
        first,
        ch_out,
        filter_size=3,
        stride=1,
        padding=1,
        active_type=paddle.activation.Linear())
    residual = shortcut(input, ch_out, stride, b_projection)
    branches = [second, residual]
    return paddle.layer.addto(input=branches, act=paddle.activation.Relu())


def bottleneck(input, ch_out, stride, b_projection):
    """Build a three-convolution bottleneck residual block.

    The main branch is a 1x1 conv-BN layer (given ``stride``), a 3x3
    conv-BN layer, and a 1x1 conv-BN layer with ``ch_out * 4`` filters
    and a linear activation. It is summed with the shortcut branch and a
    ReLU is applied to the sum.

    Args:
        input: Layer entering the block.
        ch_out: Number of filters of the first two convolutions; the
            block outputs four times as many.
        stride: Stride of the first convolution and of the shortcut.
        b_projection: Forwarded to ``shortcut`` to select a projection
            or identity shortcut.

    Returns:
        The ``addto`` layer combining both branches.
    """
    expanded = ch_out * 4
    reduce_1x1 = conv_bn_layer(
        input, ch_out, filter_size=1, stride=stride, padding=0)
    middle_3x3 = conv_bn_layer(
        reduce_1x1, ch_out, filter_size=3, stride=1, padding=1)
    expand_1x1 = conv_bn_layer(
        middle_3x3,
        expanded,
        filter_size=1,
        stride=1,
        padding=0,
        active_type=paddle.activation.Linear())
    residual = shortcut(input, expanded, stride, b_projection)
    branches = [expand_1x1, residual]
    return paddle.layer.addto(input=branches, act=paddle.activation.Relu())


def layer_warp(block_func, input, features, count, stride):
    """Chain ``count`` residual blocks into one stage.

    The first block is called with the given ``stride`` and a projection
    shortcut; every following block is called with stride 1 and an
    identity shortcut. The first block is always built, even when
    ``count`` is less than 1.

    Args:
        block_func: Callable ``(input, features, stride, b_projection)``
            that builds one block.
        input: Layer entering the stage.
        features: Passed to every block as its second argument.
        count: Total number of blocks in the stage.
        stride: Stride of the first block only.

    Returns:
        The output of the last block.
    """
    stage = block_func(input, features, stride, True)
    for _ in range(count - 1):
        stage = block_func(stage, features, 1, False)
    return stage


W
wwhu 已提交
58
def resnet_imagenet(input, depth=50, class_dim=100):
    """Build an ImageNet-style ResNet classifier on top of ``input``.

    The network is a 7x7 stride-2 conv-BN stem (built with ``ch_in=3``),
    a 3x3 stride-2 pooling layer, four residual stages with 64, 128, 256
    and 512 base channels, a 7x7 average pool and a softmax fully
    connected layer.

    Args:
        input: Image layer fed to the stem.
        depth: Network depth; one of 18, 34, 50, 101 or 152.
        class_dim: Size of the softmax output layer.

    Returns:
        The softmax output layer.

    Raises:
        KeyError: If ``depth`` is not one of the supported depths.
    """
    # depth -> (number of blocks in each of the four stages, block builder)
    cfg = {
        # 1 stem conv + 2 convs * (2 + 2 + 2 + 2) blocks + 1 fc = 18 layers.
        18: ([2, 2, 2, 2], basicblock),
        34: ([3, 4, 6, 3], basicblock),
        50: ([3, 4, 6, 3], bottleneck),
        101: ([3, 4, 23, 3], bottleneck),
        152: ([3, 8, 36, 3], bottleneck)
    }
    stages, block_func = cfg[depth]
    conv1 = conv_bn_layer(
        input, ch_in=3, ch_out=64, filter_size=7, stride=2, padding=3)
    pool1 = paddle.layer.img_pool(input=conv1, pool_size=3, stride=2)
    # Only the first stage keeps the resolution; later stages use stride 2.
    res1 = layer_warp(block_func, pool1, 64, stages[0], 1)
    res2 = layer_warp(block_func, res1, 128, stages[1], 2)
    res3 = layer_warp(block_func, res2, 256, stages[2], 2)
    res4 = layer_warp(block_func, res3, 512, stages[3], 2)
    pool2 = paddle.layer.img_pool(
        input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg())
    out = paddle.layer.fc(
        input=pool2, size=class_dim, act=paddle.activation.Softmax())
    return out
W
wwhu 已提交
79 80


W
wwhu 已提交
81
def resnet_cifar10(input, depth=32, class_dim=10):
W
wwhu 已提交
82 83 84 85 86 87 88 89 90 91 92
    # depth should be one of 20, 32, 44, 56, 110, 1202
    assert (depth - 2) % 6 == 0
    n = (depth - 2) / 6
    nStages = {16, 64, 128}
    conv1 = conv_bn_layer(
        input, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1)
    res1 = layer_warp(basicblock, conv1, 16, n, 1)
    res2 = layer_warp(basicblock, res1, 32, n, 2)
    res3 = layer_warp(basicblock, res2, 64, n, 2)
    pool = paddle.layer.img_pool(
        input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg())
W
wwhu 已提交
93 94 95
    out = paddle.layer.fc(
        input=pool, size=class_dim, act=paddle.activation.Softmax())
    return out