提交 9fb2a20e 编写于 作者: D dengkaipeng

refine test.py and infer.py

上级 3596eca2
...@@ -114,7 +114,6 @@ def infer(args): ...@@ -114,7 +114,6 @@ def infer(args):
infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds) infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
fetch_list = [x.name for x in infer_outputs] fetch_list = [x.name for x in infer_outputs]
def _infer_loop():
periods = [] periods = []
results = [] results = []
cur_time = time.time() cur_time = time.time()
...@@ -139,16 +138,12 @@ def infer(args): ...@@ -139,16 +138,12 @@ def infer(args):
logger.info('[INFER] infer finished. average time: {}'.format( logger.info('[INFER] infer finished. average time: {}'.format(
np.mean(periods))) np.mean(periods)))
return results
# start infer loop
infer_results = _infer_loop()
if not os.path.isdir(args.save_dir): if not os.path.isdir(args.save_dir):
os.mkdir(args.save_dir) os.mkdir(args.save_dir)
result_file_name = os.path.join(args.save_dir, result_file_name = os.path.join(args.save_dir,
"{}_infer_result".format(args.model_name)) "{}_infer_result".format(args.model_name))
pickle.dump(infer_results, open(result_file_name, 'wb')) pickle.dump(results, open(result_file_name, 'wb'))
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
......
...@@ -34,7 +34,7 @@ class AttentionCluster(ModelBase): ...@@ -34,7 +34,7 @@ class AttentionCluster(ModelBase):
self.feature_dims = self.cfg.MODEL.feature_dims self.feature_dims = self.cfg.MODEL.feature_dims
self.cluster_nums = self.cfg.MODEL.cluster_nums self.cluster_nums = self.cfg.MODEL.cluster_nums
self.seg_num = self.cfg.MODEL.seg_num self.seg_num = self.cfg.MODEL.seg_num
self.class_num = self.cfg.MODEL.num_classes #self.cfg.MODEL.class_num self.class_num = self.cfg.MODEL.num_classes
self.drop_rate = self.cfg.MODEL.drop_rate self.drop_rate = self.cfg.MODEL.drop_rate
if self.mode == 'train': if self.mode == 'train':
......
...@@ -97,7 +97,6 @@ def test(args): ...@@ -97,7 +97,6 @@ def test(args):
fetch_list = [loss.name] + [x.name fetch_list = [loss.name] + [x.name
for x in test_outputs] + [test_feeds[-1].name] for x in test_outputs] + [test_feeds[-1].name]
def _test_loop():
epoch_period = [] epoch_period = []
for test_iter, data in enumerate(test_reader()): for test_iter, data in enumerate(test_reader()):
cur_time = time.time() cur_time = time.time()
...@@ -116,9 +115,6 @@ def test(args): ...@@ -116,9 +115,6 @@ def test(args):
test_metrics.calculate_and_log_out(loss, pred, label, info_str) test_metrics.calculate_and_log_out(loss, pred, label, info_str)
test_metrics.finalize_and_log_out("[EVAL] eval finished. ") test_metrics.finalize_and_log_out("[EVAL] eval finished. ")
# start eval loop
_test_loop()
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
......
...@@ -188,20 +188,20 @@ def train(args): ...@@ -188,20 +188,20 @@ def train(args):
if args.no_use_pyreader: if args.no_use_pyreader:
train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds) train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds)
valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds) valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds)
train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder, \ train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder,
train_fetch_list, train_metrics, epochs = epochs, \ train_fetch_list, train_metrics, epochs = epochs,
log_interval = args.log_interval, valid_interval = args.valid_interval, \ log_interval = args.log_interval, valid_interval = args.valid_interval,
save_dir = args.save_dir, save_model_name = args.model_name, \ save_dir = args.save_dir, save_model_name = args.model_name,
test_exe = valid_exe, test_reader = valid_reader, test_feeder = valid_feeder, \ test_exe = valid_exe, test_reader = valid_reader, test_feeder = valid_feeder,
test_fetch_list = valid_fetch_list, test_metrics = valid_metrics) test_fetch_list = valid_fetch_list, test_metrics = valid_metrics)
else: else:
train_pyreader.decorate_paddle_reader(train_reader) train_pyreader.decorate_paddle_reader(train_reader)
valid_pyreader.decorate_paddle_reader(valid_reader) valid_pyreader.decorate_paddle_reader(valid_reader)
train_with_pyreader(exe, train_prog, train_exe, train_pyreader, train_fetch_list, train_metrics, \ train_with_pyreader(exe, train_prog, train_exe, train_pyreader, train_fetch_list, train_metrics,
epochs = epochs, log_interval = args.log_interval, \ epochs = epochs, log_interval = args.log_interval,
valid_interval = args.valid_interval, \ valid_interval = args.valid_interval,
save_dir = args.save_dir, save_model_name = args.model_name, \ save_dir = args.save_dir, save_model_name = args.model_name,
test_exe = valid_exe, test_pyreader = valid_pyreader, \ test_exe = valid_exe, test_pyreader = valid_pyreader,
test_fetch_list = valid_fetch_list, test_metrics = valid_metrics) test_fetch_list = valid_fetch_list, test_metrics = valid_metrics)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册