diff --git a/dygraph/seq2seq/train.py b/dygraph/seq2seq/train.py index 7d4134cbd00cb0b2de0faad6fe9a91e3a9b485fc..18fdee1da96cc9631ac3357164c349fc239f4f14 100644 --- a/dygraph/seq2seq/train.py +++ b/dygraph/seq2seq/train.py @@ -158,11 +158,13 @@ def main(): total_loss = 0 word_count = 0.0 batch_times = [] + total_reader_cost = 0.0 interval_time_start = time.time() batch_start = time.time() for batch_id, batch in enumerate(train_data_iter): batch_reader_end = time.time() + total_reader_cost += batch_reader_end - batch_start input_data_feed, word_num = prepare_input( batch, epoch_id=epoch_id) @@ -178,14 +180,15 @@ def main(): batch_times.append(train_batch_cost) if batch_id > 0 and batch_id % 100 == 0: print( - "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, batch_cost: %.5f s, reader_cost: %.5f s, speed: %.5f words/s" + "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, batch_cost: %.5f s, reader_cost: %.5f s, ips: %.5f words/s" % (epoch_id, batch_id, np.exp(total_loss.numpy() / word_count), - train_batch_cost, batch_reader_end - batch_start, + train_batch_cost, total_reader_cost / 100, word_count / (time.time() - interval_time_start))) ce_ppl.append(np.exp(total_loss.numpy() / word_count)) total_loss = 0.0 word_count = 0.0 + total_reader_cost = 0.0 interval_time_start = time.time() batch_start = time.time()