提交 a85a6014 编写于 作者: B baiyfbupt

update ce kpi

上级 35f9269a
...@@ -149,7 +149,7 @@ def train(args, ...@@ -149,7 +149,7 @@ def train(args,
save_model('best_model') save_model('best_model')
print("Pass {0}, test map {1}".format(pass_id, test_map)) print("Pass {0}, test map {1}".format(pass_id, test_map))
return best_map return best_map
'''
def ce_map(pass_id, best_map): def ce_map(pass_id, best_map):
_, accum_map = map_eval.get_map_var() _, accum_map = map_eval.get_map_var()
map_eval.reset(exe) map_eval.reset(exe)
...@@ -161,7 +161,7 @@ def train(args, ...@@ -161,7 +161,7 @@ def train(args,
if batch_id % 20 == 0: if batch_id % 20 == 0:
every_train_map.append(out) every_train_map.append(out)
train_map = np.mean(every_train_map) train_map = np.mean(every_train_map)
_, accum_map = map_eval.get_map_var() _, accum_map = map_eval.get_map_var()
map_eval.reset(exe) map_eval.reset(exe)
every_test_map = [] every_test_map = []
...@@ -173,7 +173,7 @@ def train(args, ...@@ -173,7 +173,7 @@ def train(args,
every_test_map.append(out) every_test_map.append(out)
test_map = np.mean(every_test_map) test_map = np.mean(every_test_map)
return (train_map, test_map) return (train_map, test_map)
'''
train_num = 0 train_num = 0
total_train_time = 0.0 total_train_time = 0.0
for pass_id in range(num_passes): for pass_id in range(num_passes):
...@@ -207,10 +207,10 @@ def train(args, ...@@ -207,10 +207,10 @@ def train(args,
best_map = test(pass_id, best_map) best_map = test(pass_id, best_map)
if args.for_model_ce: if args.for_model_ce:
#map_kpi = ce_map(pass_id, best_map) train_avg_acc, test_avg_acc = ce_map(pass_id, best_map)
#print ("kpis train_acc %f" % train_avg_acc) print ("kpis train_acc %f" % train_avg_acc)
print ("kpis train_cost %f" % train_avg_loss) print ("kpis train_cost %f" % train_avg_loss)
#print ("kpis test_acc %f" % test_avg_acc) print ("kpis test_acc %f" % test_avg_acc)
print ("kpis train_duration %f" % (end_time - start_time)) print ("kpis train_duration %f" % (end_time - start_time))
if pass_id % 10 == 0 or pass_id == num_passes - 1: if pass_id % 10 == 0 or pass_id == num_passes - 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册