未验证 提交 dbc27b84 编写于 作者: C chengduo 提交者: GitHub

Use multi-process reader for dygraph (#2416)

* Add multi-process reader

* Use distributed_batch_reader
上级 55138a40
......@@ -184,15 +184,11 @@ def train_mnist(args):
if args.use_data_parallel:
mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_sampler(
paddle.dataset.mnist.train(),
batch_size=BATCH_SIZE * trainer_count)
else:
train_reader = paddle.batch(
paddle.dataset.mnist.train(),
batch_size=BATCH_SIZE,
drop_last=True)
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE, drop_last=True)
......
......@@ -282,14 +282,11 @@ def train_resnet():
if args.use_data_parallel:
resnet = fluid.dygraph.parallel.DataParallel(resnet, strategy)
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_sampler(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size * trainer_count)
else:
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size)
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
test_reader = paddle.batch(
paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size)
......
......@@ -1119,16 +1119,12 @@ def train():
transformer = fluid.dygraph.parallel.DataParallel(transformer,
strategy)
reader = paddle.batch(
wmt16.train(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=TrainTaskConfig.batch_size)
if args.use_data_parallel:
reader = fluid.contrib.reader.distributed_sampler(
wmt16.train(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=TrainTaskConfig.batch_size * trainer_count)
else:
reader = paddle.batch(
wmt16.train(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=TrainTaskConfig.batch_size)
reader = fluid.contrib.reader.distributed_batch_reader(reader)
for i in range(200):
dy_step = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册