提交 4a0109ff 编写于 作者: Y Yibing Liu

Some modifications to run the model

上级 e07b56a9
...@@ -2,35 +2,38 @@ import os ...@@ -2,35 +2,38 @@ import os
import paddle.fluid as fluid import paddle.fluid as fluid
def inception_v4(image, label): def inception_v4(img, class_dim):
tmp = stem(input=image) tmp = stem(input=img)
for i in range(0, 4): for i in range(1):
tmp = inception_A(input=tmp, depth=i) tmp = inception_A(input=tmp, depth=i)
tmp = reduction_A(input=tmp) tmp = reduction_A(input=tmp)
for i in range(0, 7): for i in range(7):
tmp = inception_B(input=tmp, depth=i) tmp = inception_B(input=tmp, depth=i)
reduction_B(input=tmp) reduction_B(input=tmp)
for i in range(0, 3): for i in range(3):
tmp = inception_C(input=tmp, depth=i) tmp = inception_C(input=tmp, depth=i)
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(
pool_type='ave', input=tmp, pool_size=7, pool_stride=1) pool_type='avg', input=tmp, pool_size=7, pool_stride=1)
dropout = fluid.layers.dropout(input=pool, drop_prob=0.2) dropout = fluid.layers.dropout(x=pool, dropout_prob=0.2)
out = fluid.layers.softmax(input=dropout) fc = fluid.layers.fc(input=dropout, size=class_dim, act='softmax')
out = fluid.layers.softmax(input=fc)
return out return out
def conv_bn_layer(input, def conv_bn_layer(name,
input,
num_filters, num_filters,
filter_size, filter_size,
padding=1, padding=0,
stride=1, stride=1,
groups=1, groups=1,
act=None): act=None):
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
name=name,
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
...@@ -39,50 +42,106 @@ def conv_bn_layer(input, ...@@ -39,50 +42,106 @@ def conv_bn_layer(input,
groups=groups, groups=groups,
act=None, act=None,
bias_attr=False) bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act) return fluid.layers.batch_norm(name=name + '_norm', input=conv, act=act)
def stem(input): def stem(input):
conv1 = conv_bn_layer(input=input, num_filters=32, filter_size=3, stride=2) conv0 = conv_bn_layer(
conv2 = conv_bn_layer(input=conv1, num_filters=32, filter_size=3) name='stem_conv_0',
conv3 = conv_bn_layer(input=conv2, num_filters=64, filter_size=3) input=input,
num_filters=32,
filter_size=3,
padding=1,
stride=2)
conv1 = conv_bn_layer(
name='stem_conv_1',
input=conv0,
num_filters=32,
filter_size=3,
padding=1)
conv2 = conv_bn_layer(
name='stem_conv_2',
input=conv1,
num_filters=64,
filter_size=3,
padding=1)
def block0(input): def block0(input):
pool0 = fluid.layers.pool2d( pool0 = fluid.layers.pool2d(
input=input, pool_size=3, pool_stride=2, pool_type='max') input=input,
pool_size=3,
pool_stride=2,
pool_type='max',
pool_padding=1)
conv0 = conv_bn_layer( conv0 = conv_bn_layer(
input=input, num_filters=96, filter_size=3, stride=2) name='stem_block0_conv',
return fluid.layers.concat(input=[pool0, conv0]) input=input,
num_filters=96,
filter_size=3,
stride=2,
padding=1)
return fluid.layers.concat(input=[pool0, conv0], axis=1)
def block1(input): def block1(input):
l_conv0 = conv_bn_layer( l_conv0 = conv_bn_layer(
input=input, num_filters=64, filter_size=1, stride=1, padding=0) name='stem_block1_l_conv0',
input=input,
num_filters=64,
filter_size=1,
stride=1,
padding=0)
l_conv1 = conv_bn_layer( l_conv1 = conv_bn_layer(
input=l_conv0, num_filters=96, filter_size=3, stride=1, padding=1) name='stem_block1_l_conv1',
input=l_conv0,
num_filters=96,
filter_size=3,
stride=1,
padding=1)
r_conv0 = conv_bn_layer( r_conv0 = conv_bn_layer(
input=input, num_filters=64, filter_size=1, stride=1, padding=0) name='stem_block1_r_conv0',
input=input,
num_filters=64,
filter_size=1,
stride=1,
padding=0)
r_conv1 = conv_bn_layer( r_conv1 = conv_bn_layer(
name='stem_block1_r_conv1',
input=r_conv0, input=r_conv0,
num_filters=64, num_filters=64,
filter_size=(7, 1), filter_size=(7, 1),
stride=1, stride=1,
padding=(3, 0)) padding=(3, 0))
r_conv2 = conv_bn_layer( r_conv2 = conv_bn_layer(
name='stem_block1_r_conv2',
input=r_conv1, input=r_conv1,
num_filters=64, num_filters=64,
filter_size=(1, 7), filter_size=(1, 7),
stride=1, stride=1,
padding=(0, 3)) padding=(0, 3))
r_conv3 = conv_bn_layer( r_conv3 = conv_bn_layer(
input=r_conv2, num_filters=96, filter_size=3, stride=1, padding=1) name='stem_block1_r_conv3',
return fluid.layers.concat(input=[l_conv1, r_conv3]) input=r_conv2,
num_filters=96,
filter_size=3,
stride=1,
padding=1)
return fluid.layers.concat(input=[l_conv1, r_conv3], axis=3)
def block2(input): def block2(input):
conv0 = conv_bn_layer( conv0 = conv_bn_layer(
input=input, num_filters=192, filter_size=3, stride=2, padding=1) name='stem_block2_conv',
input=input,
num_filters=192,
filter_size=3,
stride=2,
padding=1)
pool0 = fluid.layers.pool2d( pool0 = fluid.layers.pool2d(
input=input, pool_size=3, pool_stride=2, pool_type='max') input=input,
return fluid.layers.concat(input=[conv0, pool0]) pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
return fluid.layers.concat(input=[conv0, pool0], axis=1)
conv3 = block0(conv2) conv3 = block0(conv2)
conv4 = block1(conv3) conv4 = block1(conv3)
...@@ -91,12 +150,12 @@ def stem(input): ...@@ -91,12 +150,12 @@ def stem(input):
def inception_A(input, depth): def inception_A(input, depth):
b0_pool0 = paddle.layer.pool2d( b0_pool0 = fluid.layers.pool2d(
name='inceptA{0}_branch0_pool0'.format(depth), name='inceptA{0}_branch0_pool0'.format(depth),
input=input, input=input,
pool_size=3, pool_size=3,
stride=1, pool_stride=1,
padding=1, pool_padding=1,
pool_type='avg') pool_type='avg')
b0_conv0 = conv_bn_layer( b0_conv0 = conv_bn_layer(
name='inceptA{0}_branch0_conv0'.format(depth), name='inceptA{0}_branch0_conv0'.format(depth),
...@@ -122,7 +181,6 @@ def inception_A(input, depth): ...@@ -122,7 +181,6 @@ def inception_A(input, depth):
b2_conv1 = conv_bn_layer( b2_conv1 = conv_bn_layer(
name='inceptA{0}_branch2_conv1'.format(depth), name='inceptA{0}_branch2_conv1'.format(depth),
input=b2_conv0, input=b2_conv0,
num_channels=64,
num_filters=96, num_filters=96,
filter_size=3, filter_size=3,
stride=1, stride=1,
...@@ -130,7 +188,6 @@ def inception_A(input, depth): ...@@ -130,7 +188,6 @@ def inception_A(input, depth):
b3_conv0 = conv_bn_layer( b3_conv0 = conv_bn_layer(
name='inceptA{0}_branch3_conv0'.format(depth), name='inceptA{0}_branch3_conv0'.format(depth),
input=input, input=input,
num_channels=384,
num_filters=64, num_filters=64,
filter_size=1, filter_size=1,
stride=1, stride=1,
...@@ -149,7 +206,8 @@ def inception_A(input, depth): ...@@ -149,7 +206,8 @@ def inception_A(input, depth):
filter_size=3, filter_size=3,
stride=1, stride=1,
padding=1) padding=1)
return paddle.layer.concat(input=[b0_conv0, b1_conv0, b2_conv1, b3_conv2]) return fluid.layers.concat(
input=[b0_conv0, b1_conv0, b2_conv1, b3_conv2], axis=1)
def reduction_A(input): def reduction_A(input):
...@@ -158,6 +216,7 @@ def reduction_A(input): ...@@ -158,6 +216,7 @@ def reduction_A(input):
input=input, input=input,
pool_size=3, pool_size=3,
pool_stride=2, pool_stride=2,
pool_padding=1,
pool_type='max') pool_type='max')
b1_conv0 = conv_bn_layer( b1_conv0 = conv_bn_layer(
name='ReductA_branch1_conv0', name='ReductA_branch1_conv0',
...@@ -187,7 +246,7 @@ def reduction_A(input): ...@@ -187,7 +246,7 @@ def reduction_A(input):
filter_size=3, filter_size=3,
stride=2, stride=2,
padding=1) padding=1)
return fluid.layers.concat(input=[b0_pool0, b1_conv0, b2_conv2]) return fluid.layers.concat(input=[b0_pool0, b1_conv0, b2_conv2], axis=1)
def inception_B(input, depth): def inception_B(input, depth):
...@@ -268,7 +327,8 @@ def inception_B(input, depth): ...@@ -268,7 +327,8 @@ def inception_B(input, depth):
filter_size=(7, 1), filter_size=(7, 1),
stride=1, stride=1,
padding=(3, 0)) padding=(3, 0))
return fluid.layers.concat(input=[b0_conv0, b1_conv0, b2_conv2, b3_conv4]) return fluid.layers.concat(
input=[b0_conv0, b1_conv0, b2_conv2, b3_conv4], axis=1)
def reduction_B(input): def reduction_B(input):
...@@ -277,6 +337,7 @@ def reduction_B(input): ...@@ -277,6 +337,7 @@ def reduction_B(input):
input=input, input=input,
pool_size=3, pool_size=3,
pool_stride=2, pool_stride=2,
pool_padding=1,
pool_type='max') pool_type='max')
b1_conv0 = conv_bn_layer( b1_conv0 = conv_bn_layer(
name='ReductB_branch1_conv0', name='ReductB_branch1_conv0',
...@@ -320,7 +381,7 @@ def reduction_B(input): ...@@ -320,7 +381,7 @@ def reduction_B(input):
filter_size=3, filter_size=3,
stride=2, stride=2,
padding=1) padding=1)
return fluid.layers.concat(input=[b0_pool0, b1_conv1, b2_conv3]) return fluid.layers.concat(input=[b0_pool0, b1_conv1, b2_conv3], axis=1)
def inception_C(input, depth): def inception_C(input, depth):
...@@ -402,4 +463,5 @@ def inception_C(input, depth): ...@@ -402,4 +463,5 @@ def inception_C(input, depth):
stride=1, stride=1,
padding=(0, 1)) padding=(0, 1))
return fluid.layers.concat( return fluid.layers.concat(
input=[b0_conv0, b1_conv0, b2_conv1, b2_conv2, b3_conv3, b3_conv4]) input=[b0_conv0, b1_conv0, b2_conv1, b2_conv2, b3_conv3, b3_conv4],
axis=1)
...@@ -222,7 +222,7 @@ def train_parallel_exe(args, ...@@ -222,7 +222,7 @@ def train_parallel_exe(args,
use_nccl=True, use_nccl=True,
lr_strategy=None, lr_strategy=None,
layers=50): layers=50):
class_dim = 1000 class_dim = 101
image_shape = [3, 224, 224] image_shape = [3, 224, 224]
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
...@@ -286,6 +286,7 @@ def train_parallel_exe(args, ...@@ -286,6 +286,7 @@ def train_parallel_exe(args,
train_reader = paddle.batch(reader.train(), batch_size=batch_size) train_reader = paddle.batch(reader.train(), batch_size=batch_size)
test_reader = paddle.batch(reader.test(), batch_size=batch_size) test_reader = paddle.batch(reader.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name) train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册