diff --git a/dygraph/resnet/train.py b/dygraph/resnet/train.py index d21f650b710c2cbe31415de7b434ccce80f9baf4..e92a39bde5bce633dda9452d5c0dad3399092248 100644 --- a/dygraph/resnet/train.py +++ b/dygraph/resnet/train.py @@ -25,6 +25,7 @@ from paddle.fluid import framework import math import sys +import time IMAGENET1000 = 1281167 base_lr = 0.1 @@ -45,6 +46,9 @@ def parse_args(): parser.add_argument( "-b", "--batch_size", default=32, type=int, help="set epoch") parser.add_argument("--ce", action="store_true", help="run ce") + + # NOTE:used in benchmark + parser.add_argument("--max_iter", default=0, type=int, help="the max iters to train, used in benchmark") args = parser.parse_args() return args @@ -310,6 +314,9 @@ def train_resnet(): #file_name = './model/epoch_0.npz' #model_data = np.load( file_name ) + #NOTE: used in benchmark + total_batch_num = 0 + for eop in range(epoch): resnet.train() @@ -325,6 +332,12 @@ def train_resnet(): print("load finished") for batch_id, data in enumerate(train_reader()): + + #NOTE: used in benchmark + if args.max_iter and total_batch_num == args.max_iter: + return + batch_start = time.time() + dy_x_data = np.array( [x[0].reshape(3, 224, 224) for x in data]).astype('float32') if len(np.array([x[1] @@ -356,15 +369,18 @@ def train_resnet(): optimizer.minimize(avg_loss) resnet.clear_gradients() + batch_end = time.time() + train_batch_cost = batch_end - batch_start total_loss += dy_out total_acc1 += acc_top1.numpy() total_acc5 += acc_top5.numpy() total_sample += 1 + total_batch_num = total_batch_num + 1 #this is for benchmark #print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out)) if batch_id % 10 == 0: - print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \ + print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f, batch cost: %.5f" % \ ( eop, batch_id, total_loss / total_sample, \ - total_acc1 / total_sample, total_acc5 / total_sample)) + total_acc1 / total_sample, total_acc5 / total_sample, train_batch_cost)) if args.ce: print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))