提交 f2898533 编写于 作者: X Xinghai Sun

Add batch processing pipeline with xmap_reader.

上级 b56a548e
......@@ -186,13 +186,18 @@ class DataGenerator(object):
  for instance in instance_reader():
  batch.append(instance)
  if len(batch) == batch_size:
- yield self._padding_batch(batch, padding_to, flatten)
+ yield batch
  batch = []
  if len(batch) >= min_batch_size:
- yield self._padding_batch(batch, padding_to, flatten)
+ yield batch
  self._epoch += 1
- return batch_reader
+ return paddle.reader.xmap_readers(
+ lambda batch: self._padding_batch(batch, padding_to, flatten),
+ batch_reader,
+ process_num=1,
+ buffer_size=8,
+ order=True)
  @property
  def feeding(self):
......
......@@ -101,7 +101,7 @@ def train():
  rnn_layer_size=args.rnn_layer_size,
  use_gru=args.use_gru,
  pretrained_model_path=args.init_model_path,
- share_rnn_weights=args.share_weights)
+ share_rnn_weights=args.share_rnn_weights)
  ds2_model.train(
  train_batch_reader=train_batch_reader,
  dev_batch_reader=dev_batch_reader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册