diff --git a/demo/sentiment/train_with_new_api.py b/demo/sentiment/train_with_new_api.py index f937b029068c222ae8a87d3bc02435a2b58d17f1..59a303c0d58a02aa4b8246c757d27d45df71d38a 100644 --- a/demo/sentiment/train_with_new_api.py +++ b/demo/sentiment/train_with_new_api.py @@ -134,7 +134,6 @@ def data_reader(): for i, line in enumerate(fdict): dictionary[line.split('\t')[0]] = i - print('dict len : %d' % (len(dictionary))) for line_count, line in enumerate(fdata): label, comment = line.strip().split('\t\t') label = int(label) @@ -165,7 +164,7 @@ if __name__ == '__main__': def event_handler(event): if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 1 == 0: + if event.batch_id % 100 == 0: print "Pass %d, Batch %d, Cost %f, %s" % ( event.pass_id, event.batch_id, event.cost, event.metrics) @@ -175,7 +174,8 @@ if __name__ == '__main__': trainer.train( reader=paddle.reader.batched( - data_reader, batch_size=128), + paddle.reader.shuffle( + data_reader, buf_size=4096), batch_size=128), event_handler=event_handler, reader_dict={'word': 0, 'label': 1},