diff --git a/03.image_classification/resnet.py b/03.image_classification/resnet.py index 19d20540780becf504973a23b50445d4b65dc2ef..c60d19fc59dfea31d8a9b22d974047f60475b092 100644 --- a/03.image_classification/resnet.py +++ b/03.image_classification/resnet.py @@ -36,26 +36,25 @@ def conv_bn_layer(input, return paddle.layer.batch_norm(input=tmp, act=active_type) -def shortcut(ipt, n_in, n_out, stride): - if n_in != n_out: - return conv_bn_layer(ipt, n_out, 1, stride, 0, +def shortcut(ipt, ch_in, ch_out, stride): + if ch_in != ch_out: + return conv_bn_layer(ipt, ch_out, 1, stride, 0, paddle.activation.Linear()) else: return ipt -def basicblock(ipt, ch_out, stride): - ch_in = ch_out * 2 +def basicblock(ipt, ch_in, ch_out, stride): tmp = conv_bn_layer(ipt, ch_out, 3, stride, 1) tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, paddle.activation.Linear()) short = shortcut(ipt, ch_in, ch_out, stride) return paddle.layer.addto(input=[tmp, short], act=paddle.activation.Relu()) -def layer_warp(block_func, ipt, features, count, stride): - tmp = block_func(ipt, features, stride) +def layer_warp(block_func, ipt, ch_in, ch_out, count, stride): + tmp = block_func(ipt, ch_in, ch_out, stride) for i in range(1, count): - tmp = block_func(tmp, features, 1) + tmp = block_func(tmp, ch_out, ch_out, 1) return tmp @@ -66,9 +65,9 @@ def resnet_cifar10(ipt, depth=32): nStages = {16, 64, 128} conv1 = conv_bn_layer( ipt, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1) - res1 = layer_warp(basicblock, conv1, 16, n, 1) - res2 = layer_warp(basicblock, res1, 32, n, 2) - res3 = layer_warp(basicblock, res2, 64, n, 2) + res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) + res2 = layer_warp(basicblock, res1, 16, 32, n, 2) + res3 = layer_warp(basicblock, res2, 32, 64, n, 2) pool = paddle.layer.img_pool( input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) return pool