提交 30acc565 编写于 作者: B Bai Yifan 提交者: whs

use develop PaddleDetection create_feed interface (#3767)

上级 7942a4da
......@@ -143,7 +143,7 @@ def main():
# build program
model = create(main_arch)
train_loader, train_feed_vars = create_feed(train_feed, iterable=True)
_, train_feed_vars = create_feed(train_feed, False)
train_fetches = model.train(train_feed_vars)
loss = train_fetches['loss']
lr = lr_builder()
......@@ -155,7 +155,6 @@ def main():
cfg.max_iters = 258
train_reader = create_reader(train_feed, cfg.max_iters, FLAGS.dataset_dir)
train_loader.set_sample_list_generator(train_reader, place)
exe.run(fluid.default_startup_program())
......@@ -174,7 +173,7 @@ def main():
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
_, test_feed_vars = create_feed(eval_feed, iterable=True)
_, test_feed_vars = create_feed(eval_feed, False)
fetches = model.eval(test_feed_vars)
eval_prog = eval_prog.clone(True)
......
......@@ -32,11 +32,13 @@ from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags(
......@@ -59,6 +61,8 @@ import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def eval_run(exe, compile_program, reader, keys, values, cls, test_feed):
"""
Run evaluation program, return program outputs.
......@@ -71,8 +75,7 @@ def eval_run(exe, compile_program, reader, keys, values, cls, test_feed):
has_bbox = 'bbox' in keys
for data in reader():
data = test_feed.feed(data)
feed_data = {'image': data['image'],
'im_size': data['im_size']}
feed_data = {'image': data['image'], 'im_size': data['im_size']}
outs = exe.run(compile_program,
feed=feed_data,
fetch_list=values[0],
......@@ -123,7 +126,6 @@ def main():
devices_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
if 'eval_feed' not in cfg:
eval_feed = create(main_arch + 'EvalFeed')
else:
......@@ -132,42 +134,40 @@ def main():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
_, test_feed_vars = create_feed(eval_feed, iterable=True)
_, test_feed_vars = create_feed(eval_feed, False)
eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
test_data_feed = fluid.DataFeeder(test_feed_vars.values(), place)
assert os.path.exists(FLAGS.model_path)
infer_prog, feed_names, fetch_targets = fluid.io.load_inference_model(
dirname=FLAGS.model_path, executor=exe,
model_filename=FLAGS.model_name,
params_filename=FLAGS.params_name)
dirname=FLAGS.model_path,
executor=exe,
model_filename=FLAGS.model_name,
params_filename=FLAGS.params_name)
eval_keys = ['bbox', 'gt_box', 'gt_label', 'is_difficult']
eval_values = ['multiclass_nms_0.tmp_0', 'gt_box', 'gt_label', 'is_difficult']
eval_values = [
'multiclass_nms_0.tmp_0', 'gt_box', 'gt_label', 'is_difficult'
]
eval_cls = []
eval_values[0] = fetch_targets[0]
results = eval_run(exe, infer_prog, eval_reader,
eval_keys, eval_values, eval_cls, test_data_feed)
results = eval_run(exe, infer_prog, eval_reader, eval_keys, eval_values,
eval_cls, test_data_feed)
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
eval_results(results, eval_feed, cfg.metric, cfg.num_classes,
resolution, False, FLAGS.output_eval)
eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution,
False, FLAGS.output_eval)
if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"-m",
"--model_path",
default=None,
type=str,
help="path of checkpoint")
"-m", "--model_path", default=None, type=str, help="path of checkpoint")
parser.add_argument(
"--output_eval",
default=None,
......
......@@ -143,7 +143,7 @@ def main():
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
train_loader, feed_vars = create_feed(train_feed, iterable=True)
_, feed_vars = create_feed(train_feed, False)
train_fetches = model.train(feed_vars)
loss = train_fetches['loss']
lr = lr_builder()
......@@ -151,7 +151,6 @@ def main():
optimizer.minimize(loss)
train_reader = create_reader(train_feed, cfg.max_iters, FLAGS.dataset_dir)
train_loader.set_sample_list_generator(train_reader, place)
# parse train fetches
train_keys, train_values, _ = parse_fetches(train_fetches)
......@@ -166,13 +165,12 @@ def main():
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
_, test_feed_vars = create_feed(eval_feed, iterable=True)
_, test_feed_vars = create_feed(eval_feed, False)
fetches = model.eval(test_feed_vars)
eval_prog = eval_prog.clone(True)
eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
test_data_feed = fluid.DataFeeder(test_feed_vars.values(), place)
# parse eval fetches
......
......@@ -151,7 +151,7 @@ def main():
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
train_loader, feed_vars = create_feed(train_feed, iterable=True)
_, feed_vars = create_feed(train_feed, False)
train_fetches = model.train(feed_vars)
loss = train_fetches['loss']
lr = lr_builder()
......@@ -159,7 +159,6 @@ def main():
optimizer.minimize(loss)
train_reader = create_reader(train_feed, cfg.max_iters, FLAGS.dataset_dir)
train_loader.set_sample_list_generator(train_reader, place)
# parse train fetches
train_keys, train_values, _ = parse_fetches(train_fetches)
......@@ -174,7 +173,7 @@ def main():
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
_, test_feed_vars = create_feed(eval_feed, iterable=True)
_, test_feed_vars = create_feed(eval_feed, False)
fetches = model.eval(test_feed_vars)
eval_prog = eval_prog.clone(True)
......
......@@ -134,7 +134,7 @@ def main():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
_, test_feed_vars = create_feed(eval_feed, iterable=True)
_, test_feed_vars = create_feed(eval_feed, False)
eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册