diff --git a/PaddleCV/image_classification/train.py b/PaddleCV/image_classification/train.py index ac67d43d27f847ee2e72ad15885dd91f80cdc548..150b44b6e747efae84b1efb5803be30c4f0528ce 100755 --- a/PaddleCV/image_classification/train.py +++ b/PaddleCV/image_classification/train.py @@ -234,6 +234,7 @@ def train(args): train_batch_id = 0 train_batch_time_record = [] train_batch_metrics_record = [] + train_batch_time_print_step = [] if not args.use_dali: train_iter = train_data_loader() @@ -256,8 +257,19 @@ def train(args): np.array(train_batch_metrics), axis=1) train_batch_metrics_record.append(train_batch_metrics_avg) if trainer_id == 0: - print_info("batch", train_batch_metrics_avg, train_batch_elapse, - pass_id, train_batch_id, args.print_step) + if train_batch_id % args.print_step == 0: + if len(train_batch_time_print_step) == 0: + train_batch_time_print_step_avg = train_batch_elapse + else: + train_batch_time_print_step_avg = np.mean( + train_batch_time_print_step) + train_batch_time_print_step = [] + print_info("batch", train_batch_metrics_avg, + train_batch_time_print_step_avg, pass_id, + train_batch_id, args.print_step) + else: + train_batch_time_print_step.append(train_batch_elapse) + sys.stdout.flush() train_batch_id += 1 t1 = time.time()