From 27ab65ef8b640e0fb23e947c7f6f0a8144dfbb92 Mon Sep 17 00:00:00 2001 From: hysunflower <52739577+hysunflower@users.noreply.github.com> Date: Mon, 11 May 2020 18:58:42 +0800 Subject: [PATCH] add_maxiter_for_mobilenet (#4603) --- dygraph/mobilenet/train.py | 7 +++++++ dygraph/mobilenet/utils/utility.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/dygraph/mobilenet/train.py b/dygraph/mobilenet/train.py index fbf5d54b..16e27dc4 100644 --- a/dygraph/mobilenet/train.py +++ b/dygraph/mobilenet/train.py @@ -128,6 +128,7 @@ def train_mobilenet(): test_data_loader.set_sample_list_generator(test_reader, place) # 4. train loop + total_batch_num = 0 #this is for benchmark for eop in range(args.num_epochs): if num_trainers > 1: imagenet_reader.set_shuffle_seed(eop + ( @@ -142,6 +143,8 @@ def train_mobilenet(): # 4.1 for each batch, call net() , backward(), and minimize() for img, label in train_data_loader(): t1 = time.time() + if args.max_iter and total_batch_num == args.max_iter: + return label = to_variable(label.numpy().astype('int64').reshape( int(args.batch_size // place_num), 1)) t_start = time.time() @@ -185,6 +188,10 @@ def train_mobilenet(): total_sample += 1 batch_id += 1 t_last = time.time() + + # NOTE: used for benchmark + total_batch_num = total_batch_num + 1 + if args.ce: print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample)) print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample)) diff --git a/dygraph/mobilenet/utils/utility.py b/dygraph/mobilenet/utils/utility.py index 53678ebb..a7bc9c88 100644 --- a/dygraph/mobilenet/utils/utility.py +++ b/dygraph/mobilenet/utils/utility.py @@ -117,6 +117,10 @@ def parse_args(): add_arg('drop_connect_rate', float, 0.2, "The value of drop connect rate") parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") + # NOTE: used for benchmark + add_arg('max_iter', int, 0, "The number of total train max_iters.") + + # READER AND PREPROCESS add_arg('lower_scale', float, 0.08, "The value of lower_scale in ramdom_crop") add_arg('lower_ratio', float, 3./4., "The value of lower_ratio in ramdom_crop") -- GitLab