# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2 as paddle

__all__ = ['resnet_cifar10']


def conv_bn_layer(input,
                  ch_out,
                  filter_size,
                  stride,
                  padding,
                  active_type=paddle.activation.Relu(),
                  ch_in=None):
    """Build a convolution layer followed by batch normalization.

    The convolution uses a linear activation and has its bias disabled
    (``bias_attr=False``). The non-linearity ``active_type`` is applied by
    the batch-norm layer instead.

    Args:
        input: Layer to convolve.
        ch_out: Number of convolution filters (output channels).
        filter_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        active_type: Activation applied after batch normalization. The
            default ``Relu()`` instance is created once, when the module is
            imported, and is shared by every call that omits this argument.
        ch_in: Forwarded to ``img_conv`` as ``num_channels``. ``None``
            presumably lets paddle infer it from ``input``; confirm against
            the paddle.v2 docs.

    Returns:
        The batch-norm layer wrapping the convolution.
    """
    # Bias is dropped on the conv, presumably because batch norm applies its
    # own learned shift right afterwards.
    conv = paddle.layer.img_conv(
        input=input,
        num_channels=ch_in,
        num_filters=ch_out,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        act=paddle.activation.Linear(),
        bias_attr=False)
    return paddle.layer.batch_norm(input=conv, act=active_type)


def shortcut(ipt, ch_in, ch_out, stride):
    """Return the skip-connection branch of a residual block.

    The input is passed through unchanged only when it already has the
    shape of the residual branch's output. That requires an unchanged
    channel count and no downsampling (``stride == 1``). Otherwise the
    input is projected with a 1x1 convolution (stride ``stride``, no
    padding) plus batch norm with a linear activation.

    Args:
        ipt: Input layer of the residual block.
        ch_in: Channel count of ``ipt``.
        ch_out: Channel count of the residual branch's output.
        stride: Stride used by the residual branch.

    Returns:
        ``ipt`` itself, or the projection layer.
    """
    # A strided residual branch shrinks the feature map. The skip path must
    # then be downsampled too, even when the channel count is unchanged.
    # Otherwise the element-wise sum (``addto`` in basicblock) would combine
    # layers of different sizes.
    if ch_in != ch_out or stride != 1:
        return conv_bn_layer(ipt, ch_out, 1, stride, 0,
                             paddle.activation.Linear())
    return ipt


def basicblock(ipt, ch_in, ch_out, stride):
    """Build a two-convolution residual block.

    Residual branch:
      - a 3x3 conv + batch norm + ReLU with stride ``stride``;
      - then a 3x3 conv + batch norm with a linear activation and stride 1.

    The skip branch comes from ``shortcut``. The two branches are summed
    with ``addto`` and passed through a ReLU.

    Args:
        ipt: Input layer.
        ch_in: Channel count of ``ipt``.
        ch_out: Channel count produced by the block.
        stride: Stride of the first convolution (and of the shortcut).

    Returns:
        The ``addto`` layer combining both branches.
    """
    # Creation order is kept (conv, conv, shortcut, addto) so that any
    # order-dependent layer naming stays the same.
    branch = conv_bn_layer(
        ipt, ch_out=ch_out, filter_size=3, stride=stride, padding=1)
    branch = conv_bn_layer(
        branch,
        ch_out=ch_out,
        filter_size=3,
        stride=1,
        padding=1,
        active_type=paddle.activation.Linear())
    skip = shortcut(ipt, ch_in, ch_out, stride)
    return paddle.layer.addto(
        input=[branch, skip], act=paddle.activation.Relu())


def layer_warp(block_func, ipt, ch_in, ch_out, count, stride):
    """Chain ``count`` residual blocks produced by ``block_func``.

    Only the first block sees ``ch_in`` and ``stride``. Every following
    block maps ``ch_out`` -> ``ch_out`` with stride 1. At least one block
    is always built, even when ``count`` is less than 1.

    Args:
        block_func: Callable ``(ipt, ch_in, ch_out, stride) -> layer``.
        ipt: Input layer of the stage.
        ch_in: Channel count of ``ipt``.
        ch_out: Channel count of every block's output.
        count: Number of blocks; must be an integer (it is passed to
            ``range``).
        stride: Stride of the first block.

    Returns:
        The output of the last block.
    """
    out = block_func(ipt, ch_in, ch_out, stride)
    for _ in range(count - 1):
        out = block_func(out, ch_out, ch_out, 1)
    return out


def resnet_cifar10(ipt, depth=32):
    """Build a CIFAR-10 style ResNet and return its pooled feature layer.

    Structure:
      - a 3x3 conv + batch norm + ReLU stem with 16 filters;
      - three stages of ``n`` basic blocks with 16, 32 and 64 channels,
        where the second and third stages downsample with stride 2;
      - an 8x8 average pool with stride 1.

    No classifier layer is attached here.

    Args:
        ipt: Input image layer with 3 channels. Assumed to be 32x32
            (CIFAR-10), so that the final feature map is 8x8 and the pool
            reduces it to 1x1. TODO: confirm against the caller's data
            layer.
        depth: Network depth. It must satisfy ``depth == 6 * n + 2`` with
            ``n >= 1``, e.g. 20, 32, 44, 56, 110 or 1202.

    Returns:
        The average-pooling layer.

    Raises:
        ValueError: If ``depth`` does not have the form ``6 * n + 2`` with
            ``n >= 1``. An explicit raise is used instead of ``assert``,
            which is stripped under ``python -O``.
    """
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError(
            "depth must be 6 * n + 2 with n >= 1 (e.g. 20, 32, 44, 56, "
            "110, 1202), got %r" % (depth,))
    # Floor division keeps n an int. With `/`, Python 3 would produce a
    # float, which range() in layer_warp rejects.
    n = (depth - 2) // 6
    conv1 = conv_bn_layer(
        ipt, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1)
    res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
    res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
    res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
    pool = paddle.layer.img_pool(
        input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg())
    return pool