Unverified    Commit 182d5092    authored by ruri    committed by GitHub

add ce for image classification (#4030)

Parent 9b67ef53
@@ -287,17 +287,25 @@ class ImageNetReader:
         full_lines = [line.strip() for line in flist]
         if mode != "test" and len(full_lines) < settings.batch_size:
             print(
-                "Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!"
-                .format(len(full_lines), settings.batch_size))
+                "Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!".
+                format(len(full_lines), settings.batch_size))
             os._exit(1)
         if num_trainers > 1 and mode == "train":
             assert self.shuffle_seed is not None, "multiprocess train, shuffle seed must be set!"
             np.random.RandomState(self.shuffle_seed).shuffle(
                 full_lines)
         elif shuffle:
-            np.random.shuffle(full_lines)
+            if not settings.enable_ce or settings.same_feed:
+                np.random.shuffle(full_lines)

         batch_data = []
+        if settings.same_feed:
+            temp_file = full_lines[0]
+            print("Same images({},nums:{}) will feed in the net".format(
+                str(temp_file), settings.same_feed))
+            full_lines = []
+            for i in range(settings.same_feed):
+                full_lines.append(temp_file)
         for line in full_lines:
             img_path, label = line.split()
             img_path = os.path.join(data_dir, img_path)
...
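A minimal standalone sketch of the `same_feed` / `enable_ce` branch added above (the file list and the `prepare_lines` helper are invented for illustration): a CE run skips the shuffle unless `same_feed` is set, and `same_feed` replaces the epoch's file list with N copies of its first entry so every iteration feeds identical data.

```python
import numpy as np

def prepare_lines(full_lines, enable_ce=False, same_feed=0):
    # Shuffle only when CE is off, or when same_feed will overwrite the list anyway.
    if not enable_ce or same_feed:
        np.random.shuffle(full_lines)
    if same_feed:
        temp_file = full_lines[0]
        print("Same images({},nums:{}) will feed in the net".format(
            str(temp_file), same_feed))
        # Repeat one sample `same_feed` times so every batch sees the same data.
        full_lines = [temp_file] * same_feed
    return full_lines

lines = ["img/a.jpg 0", "img/b.jpg 1", "img/c.jpg 2"]
print(prepare_lines(list(lines), enable_ce=True, same_feed=2))
# Prints two copies of whichever entry ended up first after the shuffle.
```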
@@ -54,7 +54,7 @@ def build_program(is_train, main_prog, startup_prog, args):
     else:
         model = models.__dict__[args.model]()
     with fluid.program_guard(main_prog, startup_prog):
-        if args.random_seed:
+        if args.random_seed or args.enable_ce:
             main_prog.random_seed = args.random_seed
             startup_prog.random_seed = args.random_seed
         with fluid.unique_name.guard():
@@ -79,8 +79,14 @@ def build_program(is_train, main_prog, startup_prog, args):
     return loss_out


-def validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
-             train_batch_metrics_record):
+def validate(args,
+             test_iter,
+             exe,
+             test_prog,
+             test_fetch_list,
+             pass_id,
+             train_batch_metrics_record,
+             train_batch_time_record=None):
     test_batch_time_record = []
     test_batch_metrics_record = []
     test_batch_id = 0
@@ -96,12 +102,11 @@ def validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
         test_batch_metrics_avg = np.mean(np.array(test_batch_metrics), axis=1)
         test_batch_metrics_record.append(test_batch_metrics_avg)

-        print_info(pass_id, test_batch_id, args.print_step,
-                   test_batch_metrics_avg, test_batch_elapse, "batch")
+        print_info("batch", test_batch_metrics_avg, test_batch_elapse, pass_id,
+                   test_batch_id, args.print_step)
         sys.stdout.flush()
         test_batch_id += 1

-    #train_epoch_time_avg = np.mean(np.array(train_batch_time_record))
     train_epoch_metrics_avg = np.mean(
         np.array(train_batch_metrics_record), axis=0)
@@ -109,9 +114,18 @@ def validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
     test_epoch_metrics_avg = np.mean(
         np.array(test_batch_metrics_record), axis=0)

-    print_info(pass_id, 0, 0,
-               list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
-               test_epoch_time_avg, "epoch")
+    print_info(
+        "epoch",
+        list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
+        test_epoch_time_avg,
+        pass_id=pass_id)
+    if args.enable_ce:
+        device_num = fluid.core.get_cuda_device_count() if args.use_gpu else 1
+        print_info(
+            "ce",
+            list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
+            train_batch_time_record,
+            device_num=device_num)


 def train(args):
@@ -207,8 +221,8 @@ def train(args):
                 np.array(train_batch_metrics), axis=1)
             train_batch_metrics_record.append(train_batch_metrics_avg)
             if trainer_id == 0:
-                print_info(pass_id, train_batch_id, args.print_step,
-                           train_batch_metrics_avg, train_batch_elapse, "batch")
+                print_info("batch", train_batch_metrics_avg, train_batch_elapse,
+                           pass_id, train_batch_id, args.print_step)
                 sys.stdout.flush()
             train_batch_id += 1
             t1 = time.time()
@@ -232,7 +246,7 @@ def train(args):
                 print('ExponentialMovingAverage validate over!')

         validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
-                 train_batch_metrics_record)
+                 train_batch_metrics_record, train_batch_time_record)
         #For now, save model per epoch.
         if pass_id % args.save_step == 0:
             save_model(args, exe, train_prog, pass_id)
...
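A small sketch (with invented timings) of why `train()` now forwards `train_batch_time_record` to `validate()`: the CE path averages the per-batch wall time while dropping the first ten batches, since those may include program startup cost.

```python
import numpy as np

# Invented per-batch timings; the first entries include startup overhead.
train_batch_time_record = [2.10, 0.31, 0.30, 0.29, 0.32, 0.30,
                           0.31, 0.29, 0.30, 0.31, 0.25, 0.25, 0.25]
# Mirror print_ce: drop batches 0-9 so the reported train_speed (really the
# mean per-batch time) is not skewed by warm-up.
train_speed = np.mean(np.array(train_batch_time_record[10:]))
print("kpis\ttrain_speed_card{}\t{}".format(1, train_speed))
# -> kpis	train_speed_card1	0.25
```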
@@ -136,7 +136,9 @@ def parse_args():
     add_arg('label_smoothing_epsilon', float, 0.1, "The value of label_smoothing_epsilon parameter")
     #NOTE: (2019/08/08) temporary disable use_distill
     #add_arg('use_distill', bool, False, "Whether to use distill")
+    add_arg("enable_ce", bool, False, "Whether to enable ce")
     add_arg('random_seed', int, None, "random seed")
     add_arg('use_ema', bool, False, "Whether to use ExponentialMovingAverage.")
     add_arg('ema_decay', float, 0.9999, "The value of ema decay rate")
     add_arg('padding_type', str, "SAME", "Padding type of convolution")
@@ -146,6 +148,7 @@ def parse_args():
     add_arg('profiler_path', str, './', "the profiler output file path.(used for benchmark)")
     add_arg('max_iter', int, 0, "the max train batch num.(used for benchmark)")
     add_arg('validate', int, 1, "whether validate.(used for benchmark)")
+    add_arg('same_feed', int, 0, "whether to feed same images")
     # yapf: enable
@@ -263,6 +266,10 @@ def check_args(args):
         args.data_dir
     ), "Data doesn't exist in {}, please load right path".format(args.data_dir)

+    if args.enable_ce:
+        args.random_seed = 0
+        print("CE is running now!")
+
     #check gpu
     check_gpu()
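Taken together with the `build_program` change above, this makes CE runs deterministic: `check_args` pins `random_seed` to 0, and because 0 is falsy the added `or args.enable_ce` is what keeps that seed from being skipped. A minimal sketch of the flag interplay (using a bare `argparse.Namespace` in place of the real parsed args):

```python
import argparse

args = argparse.Namespace(enable_ce=True, random_seed=None)
# check_args: CE forces a fixed seed so repeated runs are comparable.
if args.enable_ce:
    args.random_seed = 0
    print("CE is running now!")
# build_program: a seed of 0 alone would be falsy, so enable_ce keeps the branch alive.
if args.random_seed or args.enable_ce:
    print("seeding main_prog/startup_prog with", args.random_seed)
```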
@@ -344,7 +351,13 @@ def create_data_loader(is_train, args):
     return data_loader, [feed_image, feed_label]


-def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode):
+def print_info(info_mode,
+               metrics,
+               time_info,
+               pass_id=0,
+               batch_id=0,
+               print_step=1,
+               device_num=1):
     """print function

     Args:
@@ -355,6 +368,7 @@ def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode):
         time_info: time infomation
         info_mode: mode
     """
+    #XXX: Use specific name to choose pattern, not the length of metrics.
     if info_mode == "batch":
         if batch_id % print_step == 0:
             #if isinstance(metrics,np.ndarray):
@@ -402,11 +416,34 @@ def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode):
                     "%.5f" % test_acc5))
             sys.stdout.flush()
     elif info_mode == "ce":
-        raise Warning("CE code is not ready")
+        assert len(
+            metrics
+        ) == 7, "Enable CE: The Metrics should contain train_loss, train_acc1, train_acc5, test_loss, test_acc1, test_acc5, and train_speed"
+        assert len(
+            time_info
+        ) > 10, "0~9th batch statistics will drop when doing benchmark or ce, because it might be mixed with startup time, so please make sure training at least 10 batches."
+        print_ce(device_num, metrics, time_info)
+        #raise Warning("CE code is not ready")
     else:
         raise Exception("Illegal info_mode")


+def print_ce(device_num, metrics, time_info):
+    """ Print log for CE(for internal test).
+    """
+    train_loss, train_acc1, train_acc5, _, test_loss, test_acc1, test_acc5 = metrics
+    train_speed = np.mean(np.array(time_info[10:]))
+
+    print("kpis\ttrain_cost_card{}\t{}".format(device_num, train_loss))
+    print("kpis\ttrain_acc1_card{}\t{}".format(device_num, train_acc1))
+    print("kpis\ttrain_acc5_card{}\t{}".format(device_num, train_acc5))
+    print("kpis\ttest_loss_card{}\t{}".format(device_num, test_loss))
+    print("kpis\ttest_acc1_card{}\t{}".format(device_num, test_acc1))
+    print("kpis\ttest_acc5_card{}\t{}".format(device_num, test_acc5))
+    print("kpis\ttrain_speed_card{}\t{}".format(device_num, train_speed))
+
+
 def best_strategy_compiled(args, program, loss, exe):
     """make a program which wrapped by a compiled program
     """
...
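To show the log format the new CE branch emits, a hedged usage sketch of `print_ce` (all metric values are invented; it assumes the `print_ce` defined above is in scope):

```python
# Invented epoch metrics: train loss/acc1/acc5, one extra train metric that
# print_ce skips when unpacking, then test loss/acc1/acc5.
metrics = [2.31, 0.45, 0.70, 0.0, 2.40, 0.43, 0.69]
# Invented per-batch training times; entries 0-9 are dropped inside print_ce.
time_info = [0.5] * 10 + [0.25, 0.25, 0.25]
print_ce(device_num=1, metrics=metrics, time_info=time_info)
# Tab-separated output, one KPI per line, e.g.:
# kpis    train_cost_card1     2.31
# ...
# kpis    train_speed_card1    0.25
```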