未验证 提交 e675734c 编写于 作者: J Jeff Wang 提交者: GitHub

Merge pull request #533 from jetfuel/image_classification_new_api

[High-Level-API] Image classification train.py update
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle.v2 as paddle import paddle.fluid as fluid
__all__ = ['resnet_cifar10'] __all__ = ['resnet_cifar10']
...@@ -22,37 +22,35 @@ def conv_bn_layer(input, ...@@ -22,37 +22,35 @@ def conv_bn_layer(input,
filter_size, filter_size,
stride, stride,
padding, padding,
active_type=paddle.activation.Relu(), act='relu',
ch_in=None): bias_attr=False):
tmp = paddle.layer.img_conv( tmp = fluid.layers.conv2d(
input=input, input=input,
filter_size=filter_size, filter_size=filter_size,
num_channels=ch_in,
num_filters=ch_out, num_filters=ch_out,
stride=stride, stride=stride,
padding=padding, padding=padding,
act=paddle.activation.Linear(), act=None,
bias_attr=False) bias_attr=bias_attr)
return paddle.layer.batch_norm(input=tmp, act=active_type) return fluid.layers.batch_norm(input=tmp, act=act)
def shortcut(ipt, ch_in, ch_out, stride): def shortcut(input, ch_in, ch_out, stride):
if ch_in != ch_out: if ch_in != ch_out:
return conv_bn_layer(ipt, ch_out, 1, stride, 0, return conv_bn_layer(input, ch_out, 1, stride, 0, None)
paddle.activation.Linear())
else: else:
return ipt return input
def basicblock(ipt, ch_in, ch_out, stride): def basicblock(input, ch_in, ch_out, stride):
tmp = conv_bn_layer(ipt, ch_out, 3, stride, 1) tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, paddle.activation.Linear()) tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
short = shortcut(ipt, ch_in, ch_out, stride) short = shortcut(input, ch_in, ch_out, stride)
return paddle.layer.addto(input=[tmp, short], act=paddle.activation.Relu()) return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')
def layer_warp(block_func, ipt, ch_in, ch_out, count, stride): def layer_warp(block_func, input, ch_in, ch_out, count, stride):
tmp = block_func(ipt, ch_in, ch_out, stride) tmp = block_func(input, ch_in, ch_out, stride)
for i in range(1, count): for i in range(1, count):
tmp = block_func(tmp, ch_out, ch_out, 1) tmp = block_func(tmp, ch_out, ch_out, 1)
return tmp return tmp
...@@ -63,11 +61,11 @@ def resnet_cifar10(ipt, depth=32): ...@@ -63,11 +61,11 @@ def resnet_cifar10(ipt, depth=32):
assert (depth - 2) % 6 == 0 assert (depth - 2) % 6 == 0
n = (depth - 2) / 6 n = (depth - 2) / 6
nStages = {16, 64, 128} nStages = {16, 64, 128}
conv1 = conv_bn_layer( conv1 = conv_bn_layer(ipt, ch_out=16, filter_size=3, stride=1, padding=1)
ipt, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
res2 = layer_warp(basicblock, res1, 16, 32, n, 2) res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
res3 = layer_warp(basicblock, res2, 32, 64, n, 2) res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
pool = paddle.layer.img_pool( pool = fluid.layers.pool2d(
input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) input=res3, pool_size=8, pool_type='avg', pool_stride=1)
return pool predict = fluid.layers.fc(input=pool, size=10, act='softmax')
return predict
...@@ -12,92 +12,84 @@ ...@@ -12,92 +12,84 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
import sys, os from __future__ import print_function
import paddle.v2 as paddle import paddle
import paddle.fluid as fluid
import numpy
import sys
from vgg import vgg_bn_drop from vgg import vgg_bn_drop
from resnet import resnet_cifar10 from resnet import resnet_cifar10
with_gpu = os.getenv('WITH_GPU', '0') != '0'
def inference_network():
# The image is 32 * 32 with RGB representation.
data_shape = [3, 32, 32]
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
def main(): predict = resnet_cifar10(images, 32)
datadim = 3 * 32 * 32 # predict = vgg_bn_drop(images) # un-comment to use vgg net
classdim = 10 return predict
# PaddlePaddle init
paddle.init(use_gpu=with_gpu, trainer_count=1)
image = paddle.layer.data( def train_network():
name="image", type=paddle.data_type.dense_vector(datadim)) predict = inference_network()
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
return [avg_cost, accuracy]
# Add neural network config
# option 1. resnet
# net = resnet_cifar10(image, depth=32)
# option 2. vgg
net = vgg_bn_drop(image)
out = paddle.layer.fc( def train(use_cuda, train_program, params_dirname):
input=net, size=classdim, act=paddle.activation.Softmax()) BATCH_SIZE = 128
EPOCH_NUM = 2
lbl = paddle.layer.data( train_reader = paddle.batch(
name="label", type=paddle.data_type.integer_value(classdim)) paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=50000),
cost = paddle.layer.classification_cost(input=out, label=lbl) batch_size=BATCH_SIZE)
# Create parameters test_reader = paddle.batch(
parameters = paddle.parameters.create(cost) paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
# Create optimizer
momentum_optimizer = paddle.optimizer.Momentum(
momentum=0.9,
regularization=paddle.optimizer.L2Regularization(rate=0.0002 * 128),
learning_rate=0.1 / 128.0,
learning_rate_decay_a=0.1,
learning_rate_decay_b=50000 * 100,
learning_rate_schedule='discexp')
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=momentum_optimizer)
# End batch and end pass event handler
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, fluid.EndStepEvent):
if event.batch_id % 100 == 0: if event.step % 100 == 0:
print "\nPass %d, Batch %d, Cost %f, %s" % ( print("Pass %d, Batch %d, Cost %f, Acc %f" %
event.pass_id, event.batch_id, event.cost, event.metrics) (event.step, event.epoch, event.metrics[0],
event.metrics[1]))
else: else:
sys.stdout.write('.') sys.stdout.write('.')
sys.stdout.flush() sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
# save parameters if isinstance(event, fluid.EndEpochEvent):
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: avg_cost, accuracy = trainer.test(
trainer.save_parameter_to_tar(f) reader=test_reader, feed_order=['pixel', 'label'])
result = trainer.test( print('Loss {0:2.2}, Acc {1:2.2}'.format(avg_cost, accuracy))
reader=paddle.batch( if params_dirname is not None:
paddle.dataset.cifar.test10(), batch_size=128), trainer.save_params(params_dirname)
feeding={'image': 0,
'label': 1}) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) trainer = fluid.Trainer(
train_func=train_program,
# Save the inference topology to protobuf. optimizer=fluid.optimizer.Adam(learning_rate=0.001),
inference_topology = paddle.topology.Topology(layers=out) place=place)
with open("inference_topology.pkl", 'wb') as f:
inference_topology.serialize_for_inference(f)
trainer.train( trainer.train(
reader=paddle.batch( reader=train_reader,
paddle.reader.shuffle( num_epochs=EPOCH_NUM,
paddle.dataset.cifar.train10(), buf_size=50000),
batch_size=128),
num_passes=200,
event_handler=event_handler, event_handler=event_handler,
feeding={'image': 0, feed_order=['pixel', 'label'])
'label': 1})
# inference
def infer(use_cuda, inference_program, params_dirname=None):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
# Prepare testing data.
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import os import os
...@@ -105,6 +97,7 @@ def main(): ...@@ -105,6 +97,7 @@ def main():
def load_image(file): def load_image(file):
im = Image.open(file) im = Image.open(file)
im = im.resize((32, 32), Image.ANTIALIAS) im = im.resize((32, 32), Image.ANTIALIAS)
im = np.array(im).astype(np.float32) im = np.array(im).astype(np.float32)
# The storage order of the loaded image is W(widht), # The storage order of the loaded image is W(widht),
# H(height), C(channel). PaddlePaddle requires # H(height), C(channel). PaddlePaddle requires
...@@ -114,23 +107,38 @@ def main(): ...@@ -114,23 +107,38 @@ def main():
# image is B(Blue), G(green), R(Red). But PIL open # image is B(Blue), G(green), R(Red). But PIL open
# image in RGB mode. It must swap the channel order. # image in RGB mode. It must swap the channel order.
im = im[(2, 1, 0), :, :] # BGR im = im[(2, 1, 0), :, :] # BGR
im = im.flatten()
im = im / 255.0 im = im / 255.0
# Add one dimension to mimic the list format.
im = numpy.expand_dims(im, axis=0)
return im return im
test_data = []
cur_dir = os.path.dirname(os.path.realpath(__file__)) cur_dir = os.path.dirname(os.path.realpath(__file__))
test_data.append((load_image(cur_dir + '/image/dog.png'), )) img = load_image(cur_dir + '/image/dog.png')
# inference
results = inferencer.infer({'pixel': img})
print("infer results: ", results)
def main(use_cuda):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
save_path = "image_classification_resnet.inference.model"
# users can remove the comments and change the model name train(
# with open('params_pass_50.tar', 'r') as f: use_cuda=use_cuda,
# parameters = paddle.parameters.Parameters.from_tar(f) train_program=train_network,
params_dirname=save_path)
probs = paddle.infer( infer(
output_layer=out, parameters=parameters, input=test_data) use_cuda=use_cuda,
lab = np.argsort(-probs) # probs and lab are the results of one batch data inference_program=inference_network,
print "Label of image/dog.png is: %d" % lab[0][0] params_dirname=save_path)
if __name__ == '__main__': if __name__ == '__main__':
main() # For demo purpose, the training runs on CPU
# Please change accordingly.
main(use_cuda=False)
...@@ -12,36 +12,35 @@ ...@@ -12,36 +12,35 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle.v2 as paddle import paddle
import paddle.fluid as fluid
__all__ = ['vgg_bn_drop'] __all__ = ['vgg_bn_drop']
def vgg_bn_drop(input): def vgg_bn_drop(input):
def conv_block(ipt, num_filter, groups, dropouts, num_channels=None): def conv_block(ipt, num_filter, groups, dropouts):
return paddle.networks.img_conv_group( return fluid.nets.img_conv_group(
input=ipt, input=ipt,
num_channels=num_channels,
pool_size=2, pool_size=2,
pool_stride=2, pool_stride=2,
conv_num_filter=[num_filter] * groups, conv_num_filter=[num_filter] * groups,
conv_filter_size=3, conv_filter_size=3,
conv_act=paddle.activation.Relu(), conv_act='relu',
conv_with_batchnorm=True, conv_with_batchnorm=True,
conv_batchnorm_drop_rate=dropouts, conv_batchnorm_drop_rate=dropouts,
pool_type=paddle.pooling.Max()) pool_type='max')
conv1 = conv_block(input, 64, 2, [0.3, 0], 3) conv1 = conv_block(input, 64, 2, [0.3, 0])
conv2 = conv_block(conv1, 128, 2, [0.4, 0]) conv2 = conv_block(conv1, 128, 2, [0.4, 0])
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0]) conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0]) conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0])
drop = paddle.layer.dropout(input=conv5, dropout_rate=0.5) drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
fc1 = paddle.layer.fc(input=drop, size=512, act=paddle.activation.Linear()) fc1 = fluid.layers.fc(input=drop, size=512, act=None)
bn = paddle.layer.batch_norm( bn = fluid.layers.batch_norm(input=fc1, act='relu')
input=fc1, drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
act=paddle.activation.Relu(), fc2 = fluid.layers.fc(input=drop2, size=512, act=None)
layer_attr=paddle.attr.Extra(drop_rate=0.5)) predict = fluid.layers.fc(input=fc2, size=10, act='softmax')
fc2 = paddle.layer.fc(input=bn, size=512, act=paddle.activation.Linear()) return predict
return fc2 \ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册