未验证 提交 6d6f9317 编写于 作者: G guochaorong 提交者: GitHub

Merge pull request #1205 from baiyfbupt/ce

Refine ce
...@@ -8,11 +8,11 @@ from kpi import CostKpi, DurationKpi, AccKpi ...@@ -8,11 +8,11 @@ from kpi import CostKpi, DurationKpi, AccKpi
#### NOTE kpi.py should shared in models in some way!!!! #### NOTE kpi.py should shared in models in some way!!!!
train_cost_kpi = CostKpi('train_cost', 0.02, 0, actived=True) train_cost_kpi = CostKpi('train_cost', 0.02, 0, actived=True)
test_acc_kpi = AccKpi('test_acc', 0.01, 0, actived=False) test_acc_kpi = AccKpi('test_acc', 0.01, 0, actived=True)
train_speed_kpi = AccKpi('train_speed', 0.2, 0, actived=False) train_speed_kpi = AccKpi('train_speed', 0.1, 0, actived=True)
train_cost_card4_kpi = CostKpi('train_cost_card4', 0.02, 0, actived=False) train_cost_card4_kpi = CostKpi('train_cost_card4', 0.02, 0, actived=True)
test_acc_card4_kpi = AccKpi('test_acc_card4', 0.01, 0, actived=False) test_acc_card4_kpi = AccKpi('test_acc_card4', 0.01, 0, actived=True)
train_speed_card4_kpi = AccKpi('train_speed_card4', 0.2, 0, actived=True) train_speed_card4_kpi = AccKpi('train_speed_card4', 0.1, 0, actived=True)
tracking_kpis = [ tracking_kpis = [
train_cost_kpi, train_cost_kpi,
......
...@@ -45,9 +45,12 @@ def build_program(is_train, main_prog, startup_prog, args, data_args, ...@@ -45,9 +45,12 @@ def build_program(is_train, main_prog, startup_prog, args, data_args,
num_classes = 21 num_classes = 21
def get_optimizer(): def get_optimizer():
if not args.enable_ce:
optimizer = fluid.optimizer.RMSProp( optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values), learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.00005), ) regularization=fluid.regularizer.L2Decay(0.00005), )
else:
optimizer = fluid.optimizer.RMSProp(learning_rate=0.001)
return optimizer return optimizer
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
...@@ -197,12 +200,10 @@ def train(args, ...@@ -197,12 +200,10 @@ def train(args,
print("Pass {0}, test map {1}".format(pass_id, test_map)) print("Pass {0}, test map {1}".format(pass_id, test_map))
return best_map, mean_map return best_map, mean_map
total_time = 0.0
for pass_id in range(num_passes): for pass_id in range(num_passes):
epoch_idx = pass_id + 1 batch_begin = time.time()
start_time = time.time() start_time = time.time()
train_py_reader.start() train_py_reader.start()
prev_start_time = start_time
every_pass_loss = [] every_pass_loss = []
batch_id = 0 batch_id = 0
try: try:
...@@ -224,23 +225,22 @@ def train(args, ...@@ -224,23 +225,22 @@ def train(args,
break break
except fluid.core.EOFException: except fluid.core.EOFException:
train_py_reader.reset() train_py_reader.reset()
batch_end = time.time()
end_time = time.time()
best_map, mean_map = test(pass_id, best_map) best_map, mean_map = test(pass_id, best_map)
if args.enable_ce and pass_id == num_passes - 1: if args.enable_ce and pass_id == num_passes - 1:
total_time += end_time - start_time total_time = batch_end - batch_begin
train_avg_loss = np.mean(every_pass_loss) train_avg_loss = np.mean(every_pass_loss)
if devices_num == 1: if devices_num == 1:
print("kpis train_cost %s" % train_avg_loss) print("kpis train_cost %s" % train_avg_loss)
print("kpis test_acc %s" % mean_map) print("kpis test_acc %s" % mean_map)
print("kpis train_speed %s" % (total_time / epoch_idx)) print("kpis train_speed %s" % (epocs / total_time))
else: else:
print("kpis train_cost_card%s %s" % print("kpis train_cost_card%s %s" %
(devices_num, train_avg_loss)) (devices_num, train_avg_loss))
print("kpis test_acc_card%s %s" % print("kpis test_acc_card%s %s" %
(devices_num, mean_map)) (devices_num, mean_map))
print("kpis train_speed_card%s %f" % print("kpis train_speed_card%s %f" %
(devices_num, total_time / epoch_idx)) (devices_num, test_epocs / total_time))
if pass_id % 10 == 0 or pass_id == num_passes - 1: if pass_id % 10 == 0 or pass_id == num_passes - 1:
save_model(str(pass_id), train_prog) save_model(str(pass_id), train_prog)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册