Commit 842c865e authored by Zhaolong Xing, committed by GitHub

Merge pull request #325 from NHZlX/resnet_modify

The resnet.py is inconsistent with the standard model. Modify the basic…
@@ -36,26 +36,25 @@ def conv_bn_layer(input,
     return paddle.layer.batch_norm(input=tmp, act=active_type)
 
-def shortcut(ipt, n_in, n_out, stride):
-    if n_in != n_out:
-        return conv_bn_layer(ipt, n_out, 1, stride, 0,
+def shortcut(ipt, ch_in, ch_out, stride):
+    if ch_in != ch_out:
+        return conv_bn_layer(ipt, ch_out, 1, stride, 0,
                              paddle.activation.Linear())
     else:
         return ipt
 
-def basicblock(ipt, ch_out, stride):
-    ch_in = ch_out * 2
+def basicblock(ipt, ch_in, ch_out, stride):
     tmp = conv_bn_layer(ipt, ch_out, 3, stride, 1)
     tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, paddle.activation.Linear())
     short = shortcut(ipt, ch_in, ch_out, stride)
     return paddle.layer.addto(input=[tmp, short], act=paddle.activation.Relu())
 
-def layer_warp(block_func, ipt, features, count, stride):
-    tmp = block_func(ipt, features, stride)
+def layer_warp(block_func, ipt, ch_in, ch_out, count, stride):
+    tmp = block_func(ipt, ch_in, ch_out, stride)
     for i in range(1, count):
-        tmp = block_func(tmp, features, 1)
+        tmp = block_func(tmp, ch_out, ch_out, 1)
     return tmp
 
@@ -66,9 +65,9 @@ def resnet_cifar10(ipt, depth=32):
     nStages = {16, 64, 128}
     conv1 = conv_bn_layer(
         ipt, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1)
-    res1 = layer_warp(basicblock, conv1, 16, n, 1)
-    res2 = layer_warp(basicblock, res1, 32, n, 2)
-    res3 = layer_warp(basicblock, res2, 64, n, 2)
+    res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
+    res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
+    res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
     pool = paddle.layer.img_pool(
         input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg())
     return pool
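
The point of the patch: the old basicblock guessed its input width as ch_in = ch_out * 2, which never equals ch_out, so shortcut() inserted a 1x1 projection convolution into every residual block; the standard CIFAR-10 ResNet uses an identity shortcut whenever channels are unchanged. Below is a minimal plain-Python sketch of that behavior (not part of the commit; the needs_projection helper is illustrative, and n = 5 assumes the usual depth = 6n + 2 formula for depth 32):

# Sketch: which residual blocks get a 1x1 projection shortcut.
# Stage widths follow resnet_cifar10 in the diff above.

def needs_projection(ch_in, ch_out):
    # Mirrors the check in shortcut(): project only when channels change.
    return ch_in != ch_out

stages = [(16, 16), (16, 32), (32, 64)]  # (ch_in, ch_out) per stage

for ch_in, ch_out in stages:
    first = needs_projection(ch_in, ch_out)        # first block of the stage
    rest = needs_projection(ch_out, ch_out)        # remaining n - 1 blocks
    old = needs_projection(ch_out * 2, ch_out)     # old heuristic: always True
    print("stage %2d->%2d: fixed first=%s rest=%s old=%s"
          % (ch_in, ch_out, first, rest, old))

With ch_in threaded through layer_warp and basicblock explicitly, only the first block of a widening stage (16->32, 32->64) gets the projection shortcut and every other block stays an identity shortcut, which is what the standard model prescribes.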