提交 a85a6014 编写于 作者: B baiyfbupt

update ce kpi

上级 35f9269a
......@@ -149,7 +149,7 @@ def train(args,
save_model('best_model')
print("Pass {0}, test map {1}".format(pass_id, test_map))
return best_map
'''
def ce_map(pass_id, best_map):
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
......@@ -161,7 +161,7 @@ def train(args,
if batch_id % 20 == 0:
every_train_map.append(out)
train_map = np.mean(every_train_map)
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
every_test_map = []
......@@ -173,7 +173,7 @@ def train(args,
every_test_map.append(out)
test_map = np.mean(every_test_map)
return (train_map, test_map)
'''
train_num = 0
total_train_time = 0.0
for pass_id in range(num_passes):
......@@ -207,10 +207,10 @@ def train(args,
best_map = test(pass_id, best_map)
if args.for_model_ce:
#map_kpi = ce_map(pass_id, best_map)
#print ("kpis train_acc %f" % train_avg_acc)
train_avg_acc, test_avg_acc = ce_map(pass_id, best_map)
print ("kpis train_acc %f" % train_avg_acc)
print ("kpis train_cost %f" % train_avg_loss)
#print ("kpis test_acc %f" % test_avg_acc)
print ("kpis test_acc %f" % test_avg_acc)
print ("kpis train_duration %f" % (end_time - start_time))
if pass_id % 10 == 0 or pass_id == num_passes - 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册