Unverified commit 182d5092, authored by ruri, committed by GitHub

add ce for image classification (#4030)

Parent 9b67ef53
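This commit wires continuous-evaluation (CE) support through the image classification code: it adds the enable_ce and same_feed arguments, fixes the random seed and makes data shuffling conditional on the CE settings, threads train_batch_time_record into validate(), reorders the print_info() signature, and adds a print_ce() helper that emits "kpis" log lines. The sketch below shows how the new flags might be combined when launching training; it is illustrative only, the script name train.py and the concrete values are assumptions, and only the flag names come from this diff.

# Minimal usage sketch (not part of this commit). "train.py" and the value
# choices are assumptions; the flags (--model, --enable_ce, --same_feed,
# --validate) are the ones added or referenced in this diff.
import subprocess

subprocess.check_call([
    "python", "train.py",
    "--model=ResNet50",
    "--enable_ce=True",   # CE mode; check_args() then forces random_seed to 0
    "--same_feed=32",     # optional: repeat one image this many times per pass
    "--validate=1",       # keep validation on so test metrics feed the kpis lines
])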
@@ -287,17 +287,25 @@ class ImageNetReader:
full_lines = [line.strip() for line in flist]
if mode != "test" and len(full_lines) < settings.batch_size:
print(
"Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!"
.format(len(full_lines), settings.batch_size))
"Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!".
format(len(full_lines), settings.batch_size))
os._exit(1)
if num_trainers > 1 and mode == "train":
assert self.shuffle_seed is not None, "For multi-process training, the shuffle seed must be set!"
np.random.RandomState(self.shuffle_seed).shuffle(
full_lines)
elif shuffle:
- np.random.shuffle(full_lines)
+ if not settings.enable_ce or settings.same_feed:
+     np.random.shuffle(full_lines)
batch_data = []
+ if settings.same_feed:
+     temp_file = full_lines[0]
+     print("Same image ({}, nums: {}) will be fed into the net".format(
+         str(temp_file), settings.same_feed))
+     full_lines = []
+     for i in range(settings.same_feed):
+         full_lines.append(temp_file)
for line in full_lines:
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
@@ -54,7 +54,7 @@ def build_program(is_train, main_prog, startup_prog, args):
else:
model = models.__dict__[args.model]()
with fluid.program_guard(main_prog, startup_prog):
- if args.random_seed:
+ if args.random_seed or args.enable_ce:
main_prog.random_seed = args.random_seed
startup_prog.random_seed = args.random_seed
with fluid.unique_name.guard():
@@ -79,8 +79,14 @@ def build_program(is_train, main_prog, startup_prog, args):
return loss_out
- def validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
-              train_batch_metrics_record):
+ def validate(args,
+              test_iter,
+              exe,
+              test_prog,
+              test_fetch_list,
+              pass_id,
+              train_batch_metrics_record,
+              train_batch_time_record=None):
test_batch_time_record = []
test_batch_metrics_record = []
test_batch_id = 0
@@ -96,12 +102,11 @@ def validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
test_batch_metrics_avg = np.mean(np.array(test_batch_metrics), axis=1)
test_batch_metrics_record.append(test_batch_metrics_avg)
- print_info(pass_id, test_batch_id, args.print_step,
-            test_batch_metrics_avg, test_batch_elapse, "batch")
+ print_info("batch", test_batch_metrics_avg, test_batch_elapse, pass_id,
+            test_batch_id, args.print_step)
sys.stdout.flush()
test_batch_id += 1
- #train_epoch_time_avg = np.mean(np.array(train_batch_time_record))
train_epoch_metrics_avg = np.mean(
np.array(train_batch_metrics_record), axis=0)
@@ -109,9 +114,18 @@ def validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
test_epoch_metrics_avg = np.mean(
np.array(test_batch_metrics_record), axis=0)
- print_info(pass_id, 0, 0,
-            list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
-            test_epoch_time_avg, "epoch")
+ print_info(
+     "epoch",
+     list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
+     test_epoch_time_avg,
+     pass_id=pass_id)
+ if args.enable_ce:
+     device_num = fluid.core.get_cuda_device_count() if args.use_gpu else 1
+     print_info(
+         "ce",
+         list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
+         train_batch_time_record,
+         device_num=device_num)
def train(args):
@@ -207,8 +221,8 @@ def train(args):
np.array(train_batch_metrics), axis=1)
train_batch_metrics_record.append(train_batch_metrics_avg)
if trainer_id == 0:
- print_info(pass_id, train_batch_id, args.print_step,
-            train_batch_metrics_avg, train_batch_elapse, "batch")
+ print_info("batch", train_batch_metrics_avg, train_batch_elapse,
+            pass_id, train_batch_id, args.print_step)
sys.stdout.flush()
train_batch_id += 1
t1 = time.time()
@@ -232,7 +246,7 @@ def train(args):
print('ExponentialMovingAverage validate over!')
validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
- train_batch_metrics_record)
+ train_batch_metrics_record, train_batch_time_record)
#For now, save model per epoch.
if pass_id % args.save_step == 0:
save_model(args, exe, train_prog, pass_id)
@@ -136,7 +136,9 @@ def parse_args():
add_arg('label_smoothing_epsilon', float, 0.1, "The value of label_smoothing_epsilon parameter")
#NOTE: (2019/08/08) temporary disable use_distill
#add_arg('use_distill', bool, False, "Whether to use distill")
add_arg("enable_ce", bool, False, "Whether to enable ce")
add_arg('random_seed', int, None, "random seed")
add_arg('use_ema', bool, False, "Whether to use ExponentialMovingAverage.")
add_arg('ema_decay', float, 0.9999, "The value of ema decay rate")
add_arg('padding_type', str, "SAME", "Padding type of convolution")
@@ -146,6 +148,7 @@ def parse_args():
add_arg('profiler_path', str, './', "the profiler output file path.(used for benchmark)")
add_arg('max_iter', int, 0, "the max train batch num.(used for benchmark)")
add_arg('validate', int, 1, "whether validate.(used for benchmark)")
+ add_arg('same_feed', int, 0, "Number of times to feed the same image (0: disabled)")
# yapf: enable
@@ -263,6 +266,10 @@ def check_args(args):
args.data_dir
), "Data doesn't exist in {}, please load right path".format(args.data_dir)
+ if args.enable_ce:
+     args.random_seed = 0
+     print("CE is running now!")
#check gpu
check_gpu()
@@ -344,7 +351,13 @@ def create_data_loader(is_train, args):
return data_loader, [feed_image, feed_label]
- def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode):
+ def print_info(info_mode,
+                metrics,
+                time_info,
+                pass_id=0,
+                batch_id=0,
+                print_step=1,
+                device_num=1):
"""print function
Args:
@@ -355,6 +368,7 @@ def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode):
time_info: time information
info_mode: one of "batch", "epoch" or "ce"
"""
+ #XXX: Use specific name to choose pattern, not the length of metrics.
if info_mode == "batch":
if batch_id % print_step == 0:
#if isinstance(metrics,np.ndarray):
@@ -402,11 +416,34 @@ def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode):
"%.5f" % test_acc5))
sys.stdout.flush()
elif info_mode == "ce":
raise Warning("CE code is not ready")
assert len(
metrics
) == 7, "Enable CE: The Metrics should contain train_loss, train_acc1, train_acc5, test_loss, test_acc1, test_acc5, and train_speed"
assert len(
time_info
) > 10, "0~9th batch statistics will drop when doing benchmark or ce, because it might be mixed with startup time, so please make sure training at least 10 batches."
print_ce(device_num, metrics, time_info)
#raise Warning("CE code is not ready")
else:
raise Exception("Illegal info_mode")
+ def print_ce(device_num, metrics, time_info):
+     """Print log for CE (internal test)."""
+     train_loss, train_acc1, train_acc5, _, test_loss, test_acc1, test_acc5 = metrics
+     train_speed = np.mean(np.array(time_info[10:]))
+     print("kpis\ttrain_cost_card{}\t{}".format(device_num, train_loss))
+     print("kpis\ttrain_acc1_card{}\t{}".format(device_num, train_acc1))
+     print("kpis\ttrain_acc5_card{}\t{}".format(device_num, train_acc5))
+     print("kpis\ttest_loss_card{}\t{}".format(device_num, test_loss))
+     print("kpis\ttest_acc1_card{}\t{}".format(device_num, test_acc1))
+     print("kpis\ttest_acc5_card{}\t{}".format(device_num, test_acc5))
+     print("kpis\ttrain_speed_card{}\t{}".format(device_num, train_speed))
def best_strategy_compiled(args, program, loss, exe):
"""make a program which wrapped by a compiled program
"""