提交 35f9269a 编写于 作者: B baiyfbupt

add ce

上级 1ae49ef4
cp -r ./data/pascalvoc/. /home/.cache/paddle/dataset/pascalvoc
...@@ -5,7 +5,7 @@ export CUDA_VISIBLE_DEVICES=$cudaid ...@@ -5,7 +5,7 @@ export CUDA_VISIBLE_DEVICES=$cudaid
if [ ! -d "/root/.cache/paddle/dataset/pascalvoc" ];then if [ ! -d "/root/.cache/paddle/dataset/pascalvoc" ];then
mkdir -p /root/.cache/paddle/dataset/pascalvoc mkdir -p /root/.cache/paddle/dataset/pascalvoc
./data/pascalvoc/download.sh #./data/pascalvoc/download.sh
bash ./.move.sh cp -r ./data/pascalvoc/. /home/.cache/paddle/dataset/pascalvoc
fi fi
FLAGS_benchmark=true python train.py --batch_size=64 --num_passes=2 --for_model_ce=True --data_dir=/root/.cache/paddle/dataset/pascalvoc/ FLAGS_benchmark=true python train.py --for_model_ce=True --batch_size=64 --num_passes=2 --data_dir=/root/.cache/paddle/dataset/pascalvoc/ | python _ce.py
####this file is only used for continuous evaluation test!
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi, DurationKpi, AccKpi
#### NOTE: kpi.py should be shared among the models in some way!
# KPI trackers for this model. Each is built from a KPI name and a numeric
# second argument (presumably the tolerated deviation -- confirm in kpi.py).
train_cost_kpi = CostKpi('train_cost', 0.02, actived=True)
test_acc_kpi = AccKpi('test_acc', 0.005, actived=True)
train_duration_kpi = DurationKpi('train_duration', 0.06, actived=True)
train_acc_kpi = AccKpi('train_acc', 0.005, actived=True)
# Every tracker listed here is indexed by its ``name`` attribute in
# log_to_ce(); a KPI name that appears in the log but has no tracker in this
# list makes that lookup fail.
tracking_kpis = [
train_acc_kpi,
train_cost_kpi,
test_acc_kpi,
train_duration_kpi,
]
def parse_log(log):
    """Yield ``(kpi_name, kpi_value)`` pairs found in a training log.

    A KPI record is a line made of exactly three whitespace-separated
    fields: the literal marker ``kpis``, the KPI name, and a numeric value.
    Tabs and spaces are both accepted as separators, so these two lines are
    equivalent::

        kpis<TAB>train_cost<TAB>1.0
        kpis train_cost 1.0

    Every other line (ordinary training output) is ignored, as is a
    three-field ``kpis`` line whose last field is not a number.

    Args:
        log: The complete log text as a single string.

    Yields:
        Tuples ``(kpi_name, kpi_value)`` in the order the records appear,
        with ``kpi_value`` converted to ``float``.
    """
    for line in log.split('\n'):
        # split() with no argument handles tabs, single or repeated spaces
        # and strips surrounding whitespace (including a trailing '\r').
        fields = line.split()
        if len(fields) != 3 or fields[0] != 'kpis':
            continue
        try:
            kpi_value = float(fields[2])
        except ValueError:
            # Free-form log text that happens to start with "kpis" is not a
            # KPI record; skip it rather than abort the whole evaluation.
            continue
        yield fields[1], kpi_value
def log_to_ce(log):
    """Feed every KPI record parsed from ``log`` into its tracker.

    Trackers from ``tracking_kpis`` are looked up by their ``name``
    attribute. Each parsed value is echoed, added to the matching tracker,
    and the tracker is persisted right away. A KPI name with no matching
    tracker raises ``KeyError``.

    Args:
        log: The complete log text as a single string.
    """
    trackers_by_name = {tracker.name: tracker for tracker in tracking_kpis}
    for name, value in parse_log(log):
        print(name, value)
        tracker = trackers_by_name[name]
        tracker.add_record(value)
        tracker.persist()
if __name__ == '__main__':
    # The whole training log is read from stdin, echoed between marker
    # lines for debugging, then handed to the KPI trackers.
    log = sys.stdin.read()
    print("*****")
    # Function-call form: same output as the old ``print log`` statement on
    # Python 2, and valid syntax on Python 3.
    print(log)
    print("****")
    log_to_ce(log)
...@@ -23,7 +23,7 @@ add_arg('dataset', str, 'pascalvoc', "coco2014, coco2017, and pascalv ...@@ -23,7 +23,7 @@ add_arg('dataset', str, 'pascalvoc', "coco2014, coco2017, and pascalv
add_arg('model_save_dir', str, 'model', "The path to save model.") add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.") add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('apply_distort', bool, True, "Whether apply distort.") add_arg('apply_distort', bool, True, "Whether apply distort.")
add_arg('apply_expand', bool, True, "Whether appley expand.") add_arg('apply_expand', bool, True, "Whether apply expand.")
add_arg('nms_threshold', float, 0.45, "NMS threshold.") add_arg('nms_threshold', float, 0.45, "NMS threshold.")
add_arg('ap_version', str, '11point', "integral, 11point.") add_arg('ap_version', str, '11point', "integral, 11point.")
add_arg('resize_h', int, 300, "The resized image height.") add_arg('resize_h', int, 300, "The resized image height.")
...@@ -32,10 +32,8 @@ add_arg('mean_value_B', float, 127.5, "Mean value for B channel which will ...@@ -32,10 +32,8 @@ add_arg('mean_value_B', float, 127.5, "Mean value for B channel which will
add_arg('mean_value_G', float, 127.5, "Mean value for G channel which will be subtracted.") #116.78 add_arg('mean_value_G', float, 127.5, "Mean value for G channel which will be subtracted.") #116.78
add_arg('mean_value_R', float, 127.5, "Mean value for R channel which will be subtracted.") #103.94 add_arg('mean_value_R', float, 127.5, "Mean value for R channel which will be subtracted.") #103.94
add_arg('is_toy', int, 0, "Toy for quick debug, 0 means using all data, while n means using only n sample.") add_arg('is_toy', int, 0, "Toy for quick debug, 0 means using all data, while n means using only n sample.")
add_arg('for_model_ce', bool, False, "Use CE to evaluate the model")
add_arg('data_dir', str, 'data/pascalvoc', "data directory") add_arg('data_dir', str, 'data/pascalvoc', "data directory")
add_arg('skip_batch_num', int, 5, "the num of minibatch to skip.") add_arg('for_model_ce', bool, False, "Use CE to evaluate the model")
add_arg('iterations', int, 120, "mini batchs.")
#yapf: enable #yapf: enable
...@@ -151,21 +149,42 @@ def train(args, ...@@ -151,21 +149,42 @@ def train(args,
save_model('best_model') save_model('best_model')
print("Pass {0}, test map {1}".format(pass_id, test_map)) print("Pass {0}, test map {1}".format(pass_id, test_map))
return best_map return best_map
'''
def ce_map(pass_id, best_map):
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
every_train_map = []
for batch_id, data in enumerate(train_reader()):
out, = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if batch_id % 20 == 0:
every_train_map.append(out)
train_map = np.mean(every_train_map)
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
every_test_map = []
for batch_id, data in enumerate(test_reader()):
out, = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if batch_id % 20 == 0:
every_test_map.append(out)
test_map = np.mean(every_test_map)
return (train_map, test_map)
'''
train_num = 0 train_num = 0
total_train_time = 0.0 total_train_time = 0.0
for pass_id in range(num_passes): for pass_id in range(num_passes):
start_time = time.time() start_time = time.time()
prev_start_time = start_time prev_start_time = start_time
# end_time = 0
every_pass_loss = [] every_pass_loss = []
iter = 0 iter = 0
pass_duration = 0.0 pass_duration = 0.0
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time prev_start_time = start_time
start_time = time.time() start_time = time.time()
if args.for_model_ce and iter == args.iterations:
break
if len(data) < (devices_num * 2): if len(data) < (devices_num * 2):
print("There are too few data to train on all devices.") print("There are too few data to train on all devices.")
continue continue
...@@ -176,29 +195,24 @@ def train(args, ...@@ -176,29 +195,24 @@ def train(args,
loss_v, = exe.run(fluid.default_main_program(), loss_v, = exe.run(fluid.default_main_program(),
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[loss]) fetch_list=[loss])
# end_time = time.time()
loss_v = np.mean(np.array(loss_v)) loss_v = np.mean(np.array(loss_v))
if batch_id % 20 == 0: if batch_id % 20 == 0:
print("Pass {0}, batch {1}, loss {2}, time {3}".format( print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, loss_v, start_time - prev_start_time)) pass_id, batch_id, loss_v, start_time - prev_start_time))
if args.for_model_ce and iter >= args.skip_batch_num or pass_id != 0: end_time = time.time()
batch_duration = time.time() - start_time every_pass_loss.append(loss_v)
pass_duration += batch_duration
train_num += len(data)
every_pass_loss.append(loss_v)
iter += 1
total_train_time += pass_duration total_train_time += pass_duration
train_avg_loss = np.mean(every_pass_loss)
if args.for_model_ce and pass_id == num_passes - 1:
examples_per_sec = train_num / total_train_time
cost = np.mean(every_pass_loss)
with open("train_speed_factor.txt", 'w') as f:
f.write('{:f}\n'.format(examples_per_sec))
with open("train_cost_factor.txt", 'a+') as f:
f.write('{:f}\n'.format(cost))
best_map = test(pass_id, best_map) best_map = test(pass_id, best_map)
if args.for_model_ce:
#map_kpi = ce_map(pass_id, best_map)
#print ("kpis train_acc %f" % train_avg_acc)
print ("kpis train_cost %f" % train_avg_loss)
#print ("kpis test_acc %f" % test_avg_acc)
print ("kpis train_duration %f" % (end_time - start_time))
if pass_id % 10 == 0 or pass_id == num_passes - 1: if pass_id % 10 == 0 or pass_id == num_passes - 1:
save_model(str(pass_id)) save_model(str(pass_id))
print("Best test map {0}".format(best_map)) print("Best test map {0}".format(best_map))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册