提交 45d41e25 编写于 作者: G Guanghua Yu 提交者: qingqing01

Fix object_detection eval and infer.

上级 949d558c
...@@ -8,7 +8,7 @@ import math ...@@ -8,7 +8,7 @@ import math
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import reader import reader
from mobilenet_ssd import mobile_net from mobilenet_ssd import build_mobilenet_ssd
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -47,8 +47,8 @@ def build_program(main_prog, startup_prog, args, data_args): ...@@ -47,8 +47,8 @@ def build_program(main_prog, startup_prog, args, data_args):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
image, gt_box, gt_label, difficult = fluid.layers.read_file( image, gt_box, gt_label, difficult = fluid.layers.read_file(
py_reader) py_reader)
locs, confs, box, box_var = mobile_net(num_classes, image, locs, confs, box, box_var = build_mobilenet_ssd(image, num_classes,
image_shape) image_shape)
nmsed_out = fluid.layers.detection_output( nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=args.nms_threshold) locs, confs, box, box_var, nms_threshold=args.nms_threshold)
with fluid.program_guard(main_prog): with fluid.program_guard(main_prog):
...@@ -67,7 +67,6 @@ def build_program(main_prog, startup_prog, args, data_args): ...@@ -67,7 +67,6 @@ def build_program(main_prog, startup_prog, args, data_args):
def eval(args, data_args, test_list, batch_size, model_dir=None): def eval(args, data_args, test_list, batch_size, model_dir=None):
startup_prog = fluid.Program() startup_prog = fluid.Program()
test_prog = fluid.Program() test_prog = fluid.Program()
test_py_reader, map_eval = build_program( test_py_reader, map_eval = build_program(
main_prog=test_prog, main_prog=test_prog,
startup_prog=startup_prog, startup_prog=startup_prog,
......
...@@ -10,7 +10,7 @@ from PIL import ImageFont ...@@ -10,7 +10,7 @@ from PIL import ImageFont
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import reader import reader
from mobilenet_ssd import mobile_net from mobilenet_ssd import build_mobilenet_ssd
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -50,7 +50,8 @@ def infer(args, data_args, image_path, model_dir): ...@@ -50,7 +50,8 @@ def infer(args, data_args, image_path, model_dir):
label_list = data_args.label_list label_list = data_args.label_list
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape) locs, confs, box, box_var = build_mobilenet_ssd(image, num_classes,
image_shape)
nmsed_out = fluid.layers.detection_output( nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=args.nms_threshold) locs, confs, box, box_var, nms_threshold=args.nms_threshold)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册