Commit 1f137f00 authored by Wang,Jeff

Update the train.py

Parent fdcd800c
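This commit ports the recognize-digits training script from the paddle.v2 API to paddle.fluid: the networks are now built with fluid.layers / fluid.nets, the Momentum SGD trainer is replaced by fluid.optimizer.Adam driven through fluid.Trainer, and inference goes through fluid.Inferencer instead of paddle.infer. As a rough before/after sketch of the style change (using the softmax-regression layer from this file):

    # before: paddle.v2 API
    predict = paddle.layer.fc(
        input=img, size=10, act=paddle.activation.Softmax())

    # after: paddle.fluid API
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    predict = fluid.layers.fc(input=img, size=10, act='softmax')

The updated train.py: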
import os
from PIL import Image
import numpy as np
import paddle
import paddle.fluid as fluid
def softmax_regression():
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    predict = fluid.layers.fc(input=img, size=10, act='softmax')
    return predict
def multilayer_perceptron():
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    # first fully-connected layer, using ReLU as its activation function
    hidden = fluid.layers.fc(input=img, size=128, act='relu')
    # second fully-connected layer, using ReLU as its activation function
    hidden = fluid.layers.fc(input=hidden, size=64, act='relu')
    # The third fully-connected layer has size 10,
    # which is the number of unique digits.
    prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
    return prediction
def convolutional_neural_network():
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    # first conv-pool block
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        act="relu")
    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
    # second conv-pool block
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        act="relu")
    # output layer with softmax activation; size = 10 since there are only 10 possible digits
    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
    return prediction
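# Note: compared with the previous paddle.v2 version of this network, the fluid
# simple_img_conv_pool calls above no longer take num_channel (the channel count
# is taken from the input), and a batch_norm layer has been added between the
# two conv-pool blocks.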
def train_program():
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # Here we can build the prediction network in different ways. Please
    # choose one by uncommenting the corresponding line.
    # predict = softmax_regression()  # uncomment for Softmax regression
    # predict = multilayer_perceptron()  # uncomment for MLP
    predict = convolutional_neural_network()  # uncomment for LeNet5

    # Calculate the cost from the prediction and the label.
    cost = fluid.layers.cross_entropy(input=predict, label=label)
    avg_cost = fluid.layers.mean(cost)
    acc = fluid.layers.accuracy(input=predict, label=label)
    return [avg_cost, acc]
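# Note: whichever network is selected in train_program must also be the one
# passed as infer_func to fluid.Inferencer in main() below, since both rely on
# the parameters saved under params_dirname.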
def main():
    train_reader = paddle.batch(
        paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
        batch_size=64)
    test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64)

    use_cuda = os.getenv('WITH_GPU', '0') != '0'
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

    optimizer = fluid.optimizer.Adam(learning_rate=0.001)

    trainer = fluid.Trainer(
        train_func=train_program, place=place, optimizer=optimizer)

    # Save the parameters into a directory. The Inferencer can load them from there to do inference.
    params_dirname = "recognize_digits_network.inference.model"

    lists = []

    def event_handler(event):
        if isinstance(event, fluid.EndEpochEvent):
            avg_cost, acc = trainer.test(
                reader=test_reader, feed_order=['img', 'label'])

            print("Test with Epoch %d, avg_cost: %s, acc: %s" %
                  (event.epoch, avg_cost, acc))

            # save parameters
            trainer.save_params(params_dirname)
            lists.append((event.epoch, avg_cost, acc))

    # Train the model now
    trainer.train(
        num_epochs=5,
        event_handler=event_handler,
        reader=train_reader,
        feed_order=['img', 'label'])

    # find the best pass
    best = sorted(lists, key=lambda lst: float(lst[1]))[0]
    print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
    print('The classification accuracy is %.2f%%' % (float(best[2]) * 100))
    def load_image(file):
        im = Image.open(file).convert('L')
        im = im.resize((28, 28), Image.ANTIALIAS)
        im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
        # scale pixel values to [-1, 1], matching the normalization of the MNIST reader
        im = im / 255.0 * 2.0 - 1.0
        return im

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    img = load_image(cur_dir + '/image/infer_3.png')

    inferencer = fluid.Inferencer(
        # infer_func=softmax_regression,  # uncomment for softmax regression
        # infer_func=multilayer_perceptron,  # uncomment for MLP
        infer_func=convolutional_neural_network,  # uncomment for LeNet5
        param_path=params_dirname,
        place=place)

    results = inferencer.infer({'img': img})
    lab = np.argsort(results)  # probs and lab are the results of one batch of data
    print("Label of image/infer_3.png is: %d" % lab[0][0][-1])
if __name__ == '__main__':
......
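A rough note on running this script, assuming a PaddlePaddle build from the same era that still ships the fluid.Trainer / fluid.Inferencer high-level API and an image/infer_3.png next to train.py: `python train.py` trains on CPU, while setting the environment variable read inside main(), e.g. `WITH_GPU=1 python train.py`, switches training and inference to fluid.CUDAPlace(0) on a CUDA-enabled build.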