提交 f35daf3e 编写于 作者: B baiyfbupt

update ce

上级 2a1124c6
...@@ -11,6 +11,11 @@ import reader ...@@ -11,6 +11,11 @@ import reader
from mobilenet_ssd import mobile_net from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
SEED = 90
# random seed must set before configuring the network.
fluid.default_startup_program().random_seed = SEED
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
...@@ -150,31 +155,9 @@ def train(args, ...@@ -150,31 +155,9 @@ def train(args,
print("Pass {0}, test map {1}".format(pass_id, test_map)) print("Pass {0}, test map {1}".format(pass_id, test_map))
return best_map return best_map
def ce_map(pass_id, best_map): total_time = 0.0
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
every_train_map = []
for batch_id, data in enumerate(train_reader()):
out, = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if batch_id % 20 == 0:
every_train_map.append(out)
train_map = np.mean(every_train_map)
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
every_test_map = []
for batch_id, data in enumerate(test_reader()):
out, = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if batch_id % 20 == 0:
every_test_map.append(out)
test_map = np.mean(every_test_map)
return (train_map, test_map)
for pass_id in range(num_passes): for pass_id in range(num_passes):
epoch_idx = pass_id + 1
start_time = time.time() start_time = time.time()
prev_start_time = start_time prev_start_time = start_time
every_pass_loss = [] every_pass_loss = []
...@@ -200,20 +183,27 @@ def train(args, ...@@ -200,20 +183,27 @@ def train(args,
pass_id, batch_id, loss_v, start_time - prev_start_time)) pass_id, batch_id, loss_v, start_time - prev_start_time))
end_time = time.time() end_time = time.time()
best_map = test(pass_id, best_map)
if args.for_model_ce: if args.for_model_ce:
gpu_num = get_cards()
total_time += end_time - start_time
train_avg_loss = np.mean(every_pass_loss) train_avg_loss = np.mean(every_pass_loss)
train_avg_acc, test_avg_acc = ce_map(pass_id, best_map) if gpu_num == 1:
print ("kpis train_acc %f" % train_avg_acc) print ("kpis train_cost %s" % train_avg_loss)
print ("kpis train_cost %f" % train_avg_loss) print ("kpis train_speed %s" % (total_time / epoch_idx))
print ("kpis test_acc %f" % test_avg_acc) else:
print ("kpis train_duration %f" % (end_time - start_time)) print ("kpis train_cost_card%s %s" % (gpu_num, train_avg_loss))
print ("kpis train_speed_card%s %f" % (gpu_num, total_time / epoch_idx))
best_map = test(pass_id, best_map)
if pass_id % 10 == 0 or pass_id == num_passes - 1: if pass_id % 10 == 0 or pass_id == num_passes - 1:
save_model(str(pass_id)) save_model(str(pass_id))
print("Best test map {0}".format(best_map)) print("Best test map {0}".format(best_map))
def get_cards():
cards = os.environ.get('CUDA_VISIBLE_DEVICES')
num = len(cards.split(","))
return num
if __name__ == '__main__': if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册