Unverified · Commit e56a8e64, authored by chengduo, committed by GitHub

add multi cards example for mnist (#2311)

Parent 971509fa
@@ -13,7 +13,8 @@
 # limitations under the License.
 from __future__ import print_function
+import argparse
+import ast
 import numpy as np
 from PIL import Image
 import os
@@ -24,6 +25,17 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
 from paddle.fluid.dygraph.base import to_variable


+def parse_args():
+    parser = argparse.ArgumentParser("Training for Mnist.")
+    parser.add_argument(
+        "--use_data_parallel",
+        type=ast.literal_eval,
+        default=False,
help="The flag indicating whether to shuffle instances in each pass.")
+    args = parser.parse_args()
+    return args
+
+
 class SimpleImgConvPool(fluid.dygraph.Layer):
     def __init__(self,
                  name_scope,
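The new --use_data_parallel flag above is parsed with type=ast.literal_eval rather than type=bool. argparse hands the raw command-line string to the type callable, and bool("False") is True, so a plain bool flag could never be switched off from the command line; ast.literal_eval turns the strings "True" and "False" into real booleans. A quick plain-Python check (illustration only, not part of the diff):

import ast

for raw in ("True", "False"):
    print(raw, "-> bool():", bool(raw), "| ast.literal_eval():", ast.literal_eval(raw))
# prints: True -> bool(): True | ast.literal_eval(): True
#         False -> bool(): True | ast.literal_eval(): False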
@@ -105,13 +117,12 @@ class MNIST(fluid.dygraph.Layer):
         return x


-def test_train(reader, model, batch_size):
+def test_mnist(reader, model, batch_size):
     acc_set = []
     avg_loss_set = []
     for batch_id, data in enumerate(reader()):
-        dy_x_data = np.array(
-            [x[0].reshape(1, 28, 28)
-             for x in data]).astype('float32')
+        dy_x_data = np.array([x[0].reshape(1, 28, 28)
+                              for x in data]).astype('float32')
         y_data = np.array(
             [x[1] for x in data]).astype('int64').reshape(batch_size, 1)
@@ -131,24 +142,63 @@ def test_train(reader, model, batch_size):
     return avg_loss_val_mean, acc_val_mean


-def train_mnist():
+def inference_mnist():
+    with fluid.dygraph.guard():
+        mnist_infer = MNIST("mnist")
+        # load checkpoint
+        mnist_infer.load_dict(fluid.dygraph.load_persistables("save_dir"))
+        print("checkpoint loaded")
+
+        # start evaluate mode
+        mnist_infer.eval()
+
+        def load_image(file):
+            im = Image.open(file).convert('L')
+            im = im.resize((28, 28), Image.ANTIALIAS)
+            im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
+            im = im / 255.0 * 2.0 - 1.0
+            return im
+
+        cur_dir = os.path.dirname(os.path.realpath(__file__))
+        tensor_img = load_image(cur_dir + '/image/infer_3.png')
+        results = mnist_infer(to_variable(tensor_img))
+        lab = np.argsort(results.numpy())
+        print("Inference result of image/infer_3.png is: %d" % lab[0][-1])
+
+
+def train_mnist(args):
     epoch_num = 5
     BATCH_SIZE = 64
-    with fluid.dygraph.guard():
+
+    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
+        if args.use_data_parallel else fluid.CUDAPlace(0)
+    with fluid.dygraph.guard(place):
+        if args.use_data_parallel:
+            strategy = fluid.dygraph.parallel.prepare_context()
         mnist = MNIST("mnist")
         adam = AdamOptimizer(learning_rate=0.001)
-        train_reader = paddle.batch(
-            paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
+        if args.use_data_parallel:
+            mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
+
+        if args.use_data_parallel:
+            train_reader = fluid.contrib.reader.distributed_sampler(
+                paddle.dataset.mnist.train(), batch_size=BATCH_SIZE)
+        else:
+            train_reader = paddle.batch(
+                paddle.dataset.mnist.train(),
+                batch_size=BATCH_SIZE,
+                drop_last=True)
         test_reader = paddle.batch(
             paddle.dataset.mnist.test(), batch_size=BATCH_SIZE, drop_last=True)
         for epoch in range(epoch_num):
             for batch_id, data in enumerate(train_reader()):
-                dy_x_data = np.array(
-                    [x[0].reshape(1, 28, 28)
-                     for x in data]).astype('float32')
+                dy_x_data = np.array([x[0].reshape(1, 28, 28)
+                                      for x in data]).astype('float32')
                 y_data = np.array(
-                    [x[1] for x in data]).astype('int64').reshape(BATCH_SIZE, 1)
+                    [x[1] for x in data]).astype('int64').reshape(-1, 1)

                 img = to_variable(dy_x_data)
                 label = to_variable(y_data)
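Two lines in the hunk above carry the multi-card setup: fluid.dygraph.parallel.Env().dev_id binds each trainer process to its own GPU (when the script is started with one process per GPU, Env() reads that per-process environment), and fluid.contrib.reader.distributed_sampler shards the MNIST training data so each card consumes a different slice. A small refactoring sketch that isolates this selection logic, using only calls that appear in the diff; the helper names build_place and build_train_reader are invented here for illustration:

import paddle
import paddle.fluid as fluid


def build_place(use_data_parallel):
    # In data-parallel mode every process gets its own device id from the
    # environment, so each trainer builds its dygraph guard on a different card.
    if use_data_parallel:
        return fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
    return fluid.CUDAPlace(0)


def build_train_reader(use_data_parallel, batch_size):
    # distributed_sampler splits the training set across trainers; the
    # single-card path keeps the original batched reader.
    if use_data_parallel:
        return fluid.contrib.reader.distributed_sampler(
            paddle.dataset.mnist.train(), batch_size=batch_size)
    return paddle.batch(
        paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True)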
@@ -158,46 +208,33 @@ def train_mnist():
                 loss = fluid.layers.cross_entropy(cost, label)
                 avg_loss = fluid.layers.mean(loss)

-                avg_loss.backward()
+                if args.use_data_parallel:
+                    avg_loss = mnist.scale_loss(avg_loss)
+                    avg_loss.backward()
+                    mnist.apply_collective_grads()
+                else:
+                    avg_loss.backward()
                 adam.minimize(avg_loss)
                 # save checkpoint
                 mnist.clear_gradients()
                 if batch_id % 100 == 0:
-                    print("Loss at epoch {} step {}: {:}".format(epoch, batch_id, avg_loss.numpy()))
+                    print("Loss at epoch {} step {}: {:}".format(
+                        epoch, batch_id, avg_loss.numpy()))

             mnist.eval()
-            test_cost, test_acc = test_train(test_reader, mnist, BATCH_SIZE)
+            test_cost, test_acc = test_mnist(test_reader, mnist, BATCH_SIZE)
             mnist.train()
-            print("Loss at epoch {} , Test avg_loss is: {}, acc is: {}".format(epoch, test_cost, test_acc))
+            print("Loss at epoch {} , Test avg_loss is: {}, acc is: {}".format(
+                epoch, test_cost, test_acc))

         fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
         print("checkpoint saved")

-    with fluid.dygraph.guard():
-        mnist_infer = MNIST("mnist")
-        # load checkpoint
-        mnist_infer.load_dict(
-            fluid.dygraph.load_persistables("save_dir"))
-        print("checkpoint loaded")
-
-        # start evaluate mode
-        mnist_infer.eval()
-
-        def load_image(file):
-            im = Image.open(file).convert('L')
-            im = im.resize((28, 28), Image.ANTIALIAS)
-            im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
-            im = im / 255.0 * 2.0 - 1.0
-            return im
-
-        cur_dir = os.path.dirname(os.path.realpath(__file__))
-        tensor_img = load_image(cur_dir + '/image/infer_3.png')
-        results = mnist_infer(to_variable(tensor_img))
-        lab = np.argsort(results.numpy())
-        print("Inference result of image/infer_3.png is: %d" % lab[0][-1])
+    inference_mnist()


 if __name__ == '__main__':
-    train_mnist()
+    args = parse_args()
+    train_mnist(args)
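The core of the data-parallel change is the backward path added above: the loss is first scaled by the number of trainers with scale_loss(), backward() then computes local gradients, and apply_collective_grads() all-reduces them across cards before the optimizer step, so every card applies the same averaged update. A minimal sketch of that pattern, assuming model has already been wrapped in fluid.dygraph.parallel.DataParallel as in the diff; the helper name and signature are illustrative, not part of the commit:

def backward_and_step(model, optimizer, avg_loss, use_data_parallel):
    # Mirrors the branch introduced in the diff above.
    if use_data_parallel:
        # scale_loss divides the loss by the trainer count so that the
        # summed (all-reduced) gradients amount to an average.
        avg_loss = model.scale_loss(avg_loss)
        avg_loss.backward()
        # All-reduce the gradients across cards before updating parameters.
        model.apply_collective_grads()
    else:
        avg_loss.backward()
    optimizer.minimize(avg_loss)
    model.clear_gradients()

With --use_data_parallel True, a script like this is normally started with one process per GPU, for example through PaddlePaddle's paddle.distributed.launch module; the exact launcher options vary across Paddle versions and are not part of this diff.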