Unverified commit 27ab65ef, authored by hysunflower, committed via GitHub

add_maxiter_for_mobilenet (#4603)

Parent da5c04a6
@@ -128,6 +128,7 @@ def train_mobilenet():
    test_data_loader.set_sample_list_generator(test_reader, place)
    # 4. train loop
    total_batch_num = 0  # this is for benchmark
    for eop in range(args.num_epochs):
        if num_trainers > 1:
            imagenet_reader.set_shuffle_seed(eop + (
@@ -142,6 +143,8 @@ def train_mobilenet():
        # 4.1 for each batch, call net(), backward(), and minimize()
        for img, label in train_data_loader():
            t1 = time.time()
            if args.max_iter and total_batch_num == args.max_iter:
                return
            label = to_variable(label.numpy().astype('int64').reshape(
                int(args.batch_size // place_num), 1))
            t_start = time.time()
@@ -185,6 +188,10 @@ def train_mobilenet():
            total_sample += 1
            batch_id += 1
            t_last = time.time()
            # NOTE: used for benchmark
            total_batch_num = total_batch_num + 1
        if args.ce:
            print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))
            print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample))
......
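The hunks above add a simple cap for benchmark runs: total_batch_num counts batches across all epochs, and when --max_iter is non-zero the training function returns as soon as the counter reaches it, ending the whole run early; the default of 0 leaves training uncapped. A minimal, self-contained sketch of the same pattern (argument names follow the diff; the inner loop and the training step are placeholders, not the repository's code):

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=2)
    # 0 means "no cap": run every epoch to completion, as before this commit
    parser.add_argument('--max_iter', type=int, default=0)
    return parser.parse_args()

def train(args):
    total_batch_num = 0  # counts batches across all epochs, for benchmarking
    for eop in range(args.num_epochs):
        for batch_id in range(100):  # stand-in for train_data_loader()
            # returning exits the whole run, not just the current epoch
            if args.max_iter and total_batch_num == args.max_iter:
                return total_batch_num
            # forward, backward and minimize would run here
            total_batch_num += 1
    return total_batch_num

if __name__ == '__main__':
    print("batches run:", train(parse_args()))

Running this with --max_iter 300 stops after exactly 300 batches regardless of --num_epochs, which is the behavior the commit introduces for benchmarking.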
@@ -117,6 +117,10 @@ def parse_args():
    add_arg('drop_connect_rate', float, 0.2, "The value of drop connect rate")
    parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
    # NOTE: used for benchmark
    add_arg('max_iter', int, 0, "The number of total train max_iters.")
    # READER AND PREPROCESS
    add_arg('lower_scale', float, 0.08, "The value of lower_scale in ramdom_crop")
    add_arg('lower_ratio', float, 3./4., "The value of lower_ratio in ramdom_crop")
......
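This second file registers the new flag. add_arg is the repository's own helper; assuming it is a thin wrapper over argparse with the (name, type, default, help) call style visible in the diff, a rough sketch of how max_iter could be registered and read looks like this (the wrapper below is illustrative, not the project's actual utility):

import argparse

parser = argparse.ArgumentParser("MobileNet training")

def add_arg(name, dtype, default, help_text):
    # assumption: a minimal stand-in for the repo's add_arg helper,
    # forwarding to parser.add_argument with a type and default
    parser.add_argument('--' + name, type=dtype, default=default,
                        help=help_text + ' Default: %(default)s.')

# NOTE: used for benchmark; the default 0 leaves training uncapped
add_arg('max_iter', int, 0, "The number of total train max_iters.")

args = parser.parse_args(['--max_iter', '300'])
print(args.max_iter)  # 300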