提交 9fb2a20e 编写于 作者: D dengkaipeng

refine test.py and infer.py

上级 3596eca2
...@@ -114,41 +114,36 @@ def infer(args): ...@@ -114,41 +114,36 @@ def infer(args):
infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds) infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
fetch_list = [x.name for x in infer_outputs] fetch_list = [x.name for x in infer_outputs]
def _infer_loop(): periods = []
periods = [] results = []
results = [] cur_time = time.time()
for infer_iter, data in enumerate(infer_reader()):
data_feed_in = [items[:-1] for items in data]
video_id = [items[-1] for items in data]
infer_outs = exe.run(fetch_list=fetch_list,
feed=infer_feeder.feed(data_feed_in))
predictions = np.array(infer_outs[0])
for i in range(len(predictions)):
topk_inds = predictions[i].argsort()[0 - args.infer_topk:]
topk_inds = topk_inds[::-1]
preds = predictions[i][topk_inds]
results.append(
(video_id[i], preds.tolist(), topk_inds.tolist()))
prev_time = cur_time
cur_time = time.time() cur_time = time.time()
for infer_iter, data in enumerate(infer_reader()): period = cur_time - prev_time
data_feed_in = [items[:-1] for items in data] periods.append(period)
video_id = [items[-1] for items in data] logger.info('Processed {} samples'.format((infer_iter) * len(
infer_outs = exe.run(fetch_list=fetch_list, predictions)))
feed=infer_feeder.feed(data_feed_in))
predictions = np.array(infer_outs[0]) logger.info('[INFER] infer finished. average time: {}'.format(
for i in range(len(predictions)): np.mean(periods)))
topk_inds = predictions[i].argsort()[0 - args.infer_topk:]
topk_inds = topk_inds[::-1]
preds = predictions[i][topk_inds]
results.append(
(video_id[i], preds.tolist(), topk_inds.tolist()))
prev_time = cur_time
cur_time = time.time()
period = cur_time - prev_time
periods.append(period)
logger.info('Processed {} samples'.format((infer_iter) * len(
predictions)))
logger.info('[INFER] infer finished. average time: {}'.format(
np.mean(periods)))
return results
# start infer loop
infer_results = _infer_loop()
if not os.path.isdir(args.save_dir): if not os.path.isdir(args.save_dir):
os.mkdir(args.save_dir) os.mkdir(args.save_dir)
result_file_name = os.path.join(args.save_dir, result_file_name = os.path.join(args.save_dir,
"{}_infer_result".format(args.model_name)) "{}_infer_result".format(args.model_name))
pickle.dump(infer_results, open(result_file_name, 'wb')) pickle.dump(results, open(result_file_name, 'wb'))
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
......
...@@ -34,7 +34,7 @@ class AttentionCluster(ModelBase): ...@@ -34,7 +34,7 @@ class AttentionCluster(ModelBase):
self.feature_dims = self.cfg.MODEL.feature_dims self.feature_dims = self.cfg.MODEL.feature_dims
self.cluster_nums = self.cfg.MODEL.cluster_nums self.cluster_nums = self.cfg.MODEL.cluster_nums
self.seg_num = self.cfg.MODEL.seg_num self.seg_num = self.cfg.MODEL.seg_num
self.class_num = self.cfg.MODEL.num_classes #self.cfg.MODEL.class_num self.class_num = self.cfg.MODEL.num_classes
self.drop_rate = self.cfg.MODEL.drop_rate self.drop_rate = self.cfg.MODEL.drop_rate
if self.mode == 'train': if self.mode == 'train':
......
...@@ -97,27 +97,23 @@ def test(args): ...@@ -97,27 +97,23 @@ def test(args):
fetch_list = [loss.name] + [x.name fetch_list = [loss.name] + [x.name
for x in test_outputs] + [test_feeds[-1].name] for x in test_outputs] + [test_feeds[-1].name]
def _test_loop(): epoch_period = []
epoch_period = [] for test_iter, data in enumerate(test_reader()):
for test_iter, data in enumerate(test_reader()): cur_time = time.time()
cur_time = time.time() test_outs = exe.run(fetch_list=fetch_list,
test_outs = exe.run(fetch_list=fetch_list, feed=test_feeder.feed(data))
feed=test_feeder.feed(data)) period = time.time() - cur_time
period = time.time() - cur_time epoch_period.append(period)
epoch_period.append(period) loss = np.array(test_outs[0])
loss = np.array(test_outs[0]) pred = np.array(test_outs[1])
pred = np.array(test_outs[1]) label = np.array(test_outs[-1])
label = np.array(test_outs[-1]) test_metrics.accumulate(loss, pred, label)
test_metrics.accumulate(loss, pred, label)
# metric here
# metric here if args.log_interval > 0 and test_iter % args.log_interval == 0:
if args.log_interval > 0 and test_iter % args.log_interval == 0: info_str = '[EVAL] Batch {}'.format(test_iter)
info_str = '[EVAL] Batch {}'.format(test_iter) test_metrics.calculate_and_log_out(loss, pred, label, info_str)
test_metrics.calculate_and_log_out(loss, pred, label, info_str) test_metrics.finalize_and_log_out("[EVAL] eval finished. ")
test_metrics.finalize_and_log_out("[EVAL] eval finished. ")
# start eval loop
_test_loop()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -188,20 +188,20 @@ def train(args): ...@@ -188,20 +188,20 @@ def train(args):
if args.no_use_pyreader: if args.no_use_pyreader:
train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds) train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds)
valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds) valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds)
train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder, \ train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder,
train_fetch_list, train_metrics, epochs = epochs, \ train_fetch_list, train_metrics, epochs = epochs,
log_interval = args.log_interval, valid_interval = args.valid_interval, \ log_interval = args.log_interval, valid_interval = args.valid_interval,
save_dir = args.save_dir, save_model_name = args.model_name, \ save_dir = args.save_dir, save_model_name = args.model_name,
test_exe = valid_exe, test_reader = valid_reader, test_feeder = valid_feeder, \ test_exe = valid_exe, test_reader = valid_reader, test_feeder = valid_feeder,
test_fetch_list = valid_fetch_list, test_metrics = valid_metrics) test_fetch_list = valid_fetch_list, test_metrics = valid_metrics)
else: else:
train_pyreader.decorate_paddle_reader(train_reader) train_pyreader.decorate_paddle_reader(train_reader)
valid_pyreader.decorate_paddle_reader(valid_reader) valid_pyreader.decorate_paddle_reader(valid_reader)
train_with_pyreader(exe, train_prog, train_exe, train_pyreader, train_fetch_list, train_metrics, \ train_with_pyreader(exe, train_prog, train_exe, train_pyreader, train_fetch_list, train_metrics,
epochs = epochs, log_interval = args.log_interval, \ epochs = epochs, log_interval = args.log_interval,
valid_interval = args.valid_interval, \ valid_interval = args.valid_interval,
save_dir = args.save_dir, save_model_name = args.model_name, \ save_dir = args.save_dir, save_model_name = args.model_name,
test_exe = valid_exe, test_pyreader = valid_pyreader, \ test_exe = valid_exe, test_pyreader = valid_pyreader,
test_fetch_list = valid_fetch_list, test_metrics = valid_metrics) test_fetch_list = valid_fetch_list, test_metrics = valid_metrics)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册