未验证 提交 8bb5f593 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #1363 from panyx0718/pick_fix

fix model to use different names for train and test so that var name …
......@@ -88,25 +88,27 @@ def build_program(main_prog, startup_prog, train_params, is_train):
image, gt_box, gt_label, difficult = fluid.layers.read_file(py_reader)
locs, confs, box, box_var = mobile_net(class_num, image, image_shape)
if is_train:
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
box_var)
loss = fluid.layers.reduce_sum(loss)
optimizer = optimizer_setting(train_params)
optimizer.minimize(loss)
with fluid.unique_name.guard("train"):
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
box_var)
loss = fluid.layers.reduce_sum(loss)
optimizer = optimizer_setting(train_params)
optimizer.minimize(loss)
else:
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
with fluid.program_guard(main_prog):
loss = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
difficult,
class_num,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version=args.ap_version)
with fluid.unique_name.guard("inference"):
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
with fluid.program_guard(main_prog):
loss = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
difficult,
class_num,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version=args.ap_version)
return py_reader, loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册