Commit 9fb2a20e, authored by: D dengkaipeng

refine test.py and infer.py

Parent: 3596eca2
......
@@ -114,41 +114,36 @@ def infer(args):
     infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
     fetch_list = [x.name for x in infer_outputs]
-    def _infer_loop():
-        periods = []
-        results = []
-        cur_time = time.time()
-        for infer_iter, data in enumerate(infer_reader()):
-            data_feed_in = [items[:-1] for items in data]
-            video_id = [items[-1] for items in data]
-            infer_outs = exe.run(fetch_list=fetch_list,
-                                 feed=infer_feeder.feed(data_feed_in))
-            predictions = np.array(infer_outs[0])
-            for i in range(len(predictions)):
-                topk_inds = predictions[i].argsort()[0 - args.infer_topk:]
-                topk_inds = topk_inds[::-1]
-                preds = predictions[i][topk_inds]
-                results.append(
-                    (video_id[i], preds.tolist(), topk_inds.tolist()))
-            prev_time = cur_time
-            cur_time = time.time()
-            period = cur_time - prev_time
-            periods.append(period)
-            logger.info('Processed {} samples'.format((infer_iter) * len(
-                predictions)))
-        logger.info('[INFER] infer finished. average time: {}'.format(
-            np.mean(periods)))
-        return results
-    # start infer loop
-    infer_results = _infer_loop()
+    periods = []
+    results = []
+    cur_time = time.time()
+    for infer_iter, data in enumerate(infer_reader()):
+        data_feed_in = [items[:-1] for items in data]
+        video_id = [items[-1] for items in data]
+        infer_outs = exe.run(fetch_list=fetch_list,
+                             feed=infer_feeder.feed(data_feed_in))
+        predictions = np.array(infer_outs[0])
+        for i in range(len(predictions)):
+            topk_inds = predictions[i].argsort()[0 - args.infer_topk:]
+            topk_inds = topk_inds[::-1]
+            preds = predictions[i][topk_inds]
+            results.append(
+                (video_id[i], preds.tolist(), topk_inds.tolist()))
+        prev_time = cur_time
+        cur_time = time.time()
+        period = cur_time - prev_time
+        periods.append(period)
+        logger.info('Processed {} samples'.format((infer_iter) * len(
+            predictions)))
+    logger.info('[INFER] infer finished. average time: {}'.format(
+        np.mean(periods)))
     if not os.path.isdir(args.save_dir):
         os.mkdir(args.save_dir)
     result_file_name = os.path.join(args.save_dir,
                                     "{}_infer_result".format(args.model_name))
-    pickle.dump(infer_results, open(result_file_name, 'wb'))
+    pickle.dump(results, open(result_file_name, 'wb'))

 if __name__ == "__main__":
     args = parse_args()
......
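The top-k selection in the infer loop above relies on numpy argsort; note that the slice written as [0 - args.infer_topk:] is simply [-args.infer_topk:]. A minimal standalone sketch of the same indexing pattern follows (the score array and the infer_topk value are made up for illustration; in infer.py the scores come from exe.run):

import numpy as np

# Hypothetical per-video class scores; in infer.py they come out of exe.run().
predictions = np.array([[0.10, 0.55, 0.05, 0.30],
                        [0.70, 0.05, 0.20, 0.05]])
infer_topk = 3  # stands in for args.infer_topk

for i in range(len(predictions)):
    # argsort() is ascending, so the last k indices are the highest scores;
    # [::-1] reverses them so the best class comes first.
    topk_inds = predictions[i].argsort()[-infer_topk:][::-1]
    preds = predictions[i][topk_inds]
    print(topk_inds.tolist(), preds.tolist())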
......
@@ -34,7 +34,7 @@ class AttentionCluster(ModelBase):
         self.feature_dims = self.cfg.MODEL.feature_dims
         self.cluster_nums = self.cfg.MODEL.cluster_nums
         self.seg_num = self.cfg.MODEL.seg_num
-        self.class_num = self.cfg.MODEL.num_classes #self.cfg.MODEL.class_num
+        self.class_num = self.cfg.MODEL.num_classes
         self.drop_rate = self.cfg.MODEL.drop_rate
         if self.mode == 'train':
......
......
@@ -97,27 +97,23 @@ def test(args):
     fetch_list = [loss.name] + [x.name
                                 for x in test_outputs] + [test_feeds[-1].name]
-    def _test_loop():
-        epoch_period = []
-        for test_iter, data in enumerate(test_reader()):
-            cur_time = time.time()
-            test_outs = exe.run(fetch_list=fetch_list,
-                                feed=test_feeder.feed(data))
-            period = time.time() - cur_time
-            epoch_period.append(period)
-            loss = np.array(test_outs[0])
-            pred = np.array(test_outs[1])
-            label = np.array(test_outs[-1])
-            test_metrics.accumulate(loss, pred, label)
-            # metric here
-            if args.log_interval > 0 and test_iter % args.log_interval == 0:
-                info_str = '[EVAL] Batch {}'.format(test_iter)
-                test_metrics.calculate_and_log_out(loss, pred, label, info_str)
-        test_metrics.finalize_and_log_out("[EVAL] eval finished. ")
-    # start eval loop
-    _test_loop()
+    epoch_period = []
+    for test_iter, data in enumerate(test_reader()):
+        cur_time = time.time()
+        test_outs = exe.run(fetch_list=fetch_list,
+                            feed=test_feeder.feed(data))
+        period = time.time() - cur_time
+        epoch_period.append(period)
+        loss = np.array(test_outs[0])
+        pred = np.array(test_outs[1])
+        label = np.array(test_outs[-1])
+        test_metrics.accumulate(loss, pred, label)
+        # metric here
+        if args.log_interval > 0 and test_iter % args.log_interval == 0:
+            info_str = '[EVAL] Batch {}'.format(test_iter)
+            test_metrics.calculate_and_log_out(loss, pred, label, info_str)
+    test_metrics.finalize_and_log_out("[EVAL] eval finished. ")

 if __name__ == "__main__":
......
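The eval loop above wraps each exe.run call with wall-clock timing, accumulates the per-batch periods, and logs every log_interval batches. A minimal sketch of that timing pattern, using a hypothetical run_batch stand-in instead of a real executor and made-up batch and interval counts:

import time
import numpy as np

def run_batch(i):
    # Stand-in for exe.run(fetch_list=..., feed=test_feeder.feed(data)).
    time.sleep(0.01)
    return i

epoch_period = []
log_interval = 2  # stands in for args.log_interval

for test_iter in range(6):
    cur_time = time.time()
    run_batch(test_iter)
    period = time.time() - cur_time
    epoch_period.append(period)
    if log_interval > 0 and test_iter % log_interval == 0:
        print('[EVAL] Batch {} took {:.4f}s'.format(test_iter, period))

print('[EVAL] eval finished. average time: {:.4f}s'.format(np.mean(epoch_period)))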
......
@@ -188,20 +188,20 @@ def train(args):
     if args.no_use_pyreader:
         train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds)
         valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds)
-        train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder, \
-                        train_fetch_list, train_metrics, epochs = epochs, \
-                        log_interval = args.log_interval, valid_interval = args.valid_interval, \
-                        save_dir = args.save_dir, save_model_name = args.model_name, \
-                        test_exe = valid_exe, test_reader = valid_reader, test_feeder = valid_feeder, \
+        train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder,
+                        train_fetch_list, train_metrics, epochs = epochs,
+                        log_interval = args.log_interval, valid_interval = args.valid_interval,
+                        save_dir = args.save_dir, save_model_name = args.model_name,
+                        test_exe = valid_exe, test_reader = valid_reader, test_feeder = valid_feeder,
                         test_fetch_list = valid_fetch_list, test_metrics = valid_metrics)
     else:
         train_pyreader.decorate_paddle_reader(train_reader)
         valid_pyreader.decorate_paddle_reader(valid_reader)
-        train_with_pyreader(exe, train_prog, train_exe, train_pyreader, train_fetch_list, train_metrics, \
-                        epochs = epochs, log_interval = args.log_interval, \
-                        valid_interval = args.valid_interval, \
-                        save_dir = args.save_dir, save_model_name = args.model_name, \
-                        test_exe = valid_exe, test_pyreader = valid_pyreader, \
+        train_with_pyreader(exe, train_prog, train_exe, train_pyreader, train_fetch_list, train_metrics,
+                        epochs = epochs, log_interval = args.log_interval,
+                        valid_interval = args.valid_interval,
+                        save_dir = args.save_dir, save_model_name = args.model_name,
+                        test_exe = valid_exe, test_pyreader = valid_pyreader,
                         test_fetch_list = valid_fetch_list, test_metrics = valid_metrics)
......
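The train.py hunk only drops the trailing backslashes from the two call sites: inside an open parenthesis Python continues the statement implicitly, so the explicit backslash continuations were redundant. A tiny sketch of the two equivalent styles (the launch function and its arguments are made up for illustration):

def launch(job, epochs, log_interval, save_dir):
    # Hypothetical stand-in for train_with_pyreader / train_without_pyreader.
    print(job, epochs, log_interval, save_dir)

# Old style: explicit backslash continuations (legal, but redundant here).
launch("demo", epochs=10, \
       log_interval=5, \
       save_dir="./checkpoints")

# New style: the open parenthesis already continues the call implicitly.
launch("demo", epochs=10,
       log_interval=5,
       save_dir="./checkpoints")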