提交 3a5f98c3 编写于 作者: H hedaoyuan

Add reader.shuffle

上级 82ec9f22
@@ -134,7 +134,6 @@ def data_reader():
 for i, line in enumerate(fdict):
     dictionary[line.split('\t')[0]] = i
-print('dict len : %d' % (len(dictionary)))
 for line_count, line in enumerate(fdata):
     label, comment = line.strip().split('\t\t')
     label = int(label)
@@ -165,7 +164,7 @@ if __name__ == '__main__':
 def event_handler(event):
     if isinstance(event, paddle.event.EndIteration):
-        if event.batch_id % 1 == 0:
+        if event.batch_id % 100 == 0:
             print "Pass %d, Batch %d, Cost %f, %s" % (
                 event.pass_id, event.batch_id, event.cost, event.metrics)
@@ -175,7 +174,8 @@ if __name__ == '__main__':
 trainer.train(
     reader=paddle.reader.batched(
-        data_reader, batch_size=128),
+        paddle.reader.shuffle(
+            data_reader, buf_size=4096), batch_size=128),
     event_handler=event_handler,
     reader_dict={'word': 0,
                  'label': 1},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册