diff --git a/dygraph/mobilenet/train.py b/dygraph/mobilenet/train.py index fbf5d54beac044f76228076eb5d6f13e70e252af..16e27dc4fbc22675e2446dbc5ff146e1b6b5b909 100644 --- a/dygraph/mobilenet/train.py +++ b/dygraph/mobilenet/train.py @@ -128,6 +128,7 @@ def train_mobilenet(): test_data_loader.set_sample_list_generator(test_reader, place) # 4. train loop + total_batch_num = 0 #this is for benchmark for eop in range(args.num_epochs): if num_trainers > 1: imagenet_reader.set_shuffle_seed(eop + ( @@ -142,6 +143,8 @@ def train_mobilenet(): # 4.1 for each batch, call net() , backward(), and minimize() for img, label in train_data_loader(): t1 = time.time() + if args.max_iter and total_batch_num == args.max_iter: + return label = to_variable(label.numpy().astype('int64').reshape( int(args.batch_size // place_num), 1)) t_start = time.time() @@ -185,6 +188,10 @@ def train_mobilenet(): total_sample += 1 batch_id += 1 t_last = time.time() + + # NOTE: used for benchmark + total_batch_num = total_batch_num + 1 + if args.ce: print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample)) print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample)) diff --git a/dygraph/mobilenet/utils/utility.py b/dygraph/mobilenet/utils/utility.py index 53678ebb72010c44e70a43b0b084ad54a72e6ca1..a7bc9c883edba2e6115d3fe96a61e569b5d7407a 100644 --- a/dygraph/mobilenet/utils/utility.py +++ b/dygraph/mobilenet/utils/utility.py @@ -117,6 +117,10 @@ def parse_args(): add_arg('drop_connect_rate', float, 0.2, "The value of drop connect rate") parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") + # NOTE: used for benchmark + add_arg('max_iter', int, 0, "The number of total train max_iters.") + + # READER AND PREPROCESS add_arg('lower_scale', float, 0.08, "The value of lower_scale in ramdom_crop") add_arg('lower_ratio', float, 3./4., "The value of lower_ratio in ramdom_crop")