提交 2a2b7f4c 编写于 作者: D dangqingqing

Refine train.py for MobileNet-SSD.

上级 0675f9cd
......@@ -216,7 +216,7 @@ def distort_image(img, settings):
def expand_image(img, bbox_labels, img_width, img_height, settings):
prob = random.uniform(0, 1)
if prob < settings._hue_prob:
if prob < settings._expand_prob:
expand_ratio = random.uniform(1, settings._expand_max_ratio)
if expand_ratio - 1 >= 0.01:
height = int(img_height * expand_ratio)
......
......@@ -50,11 +50,12 @@ def train(args,
box, box_var)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.reduce_sum(loss)
pd.write_output(loss)
pd.write_output(nmsed_out)
loss, nmsed_out = pd()
loss = fluid.layers.reduce_sum(loss)
loss = fluid.layers.mean(loss)
else:
locs, confs, box, box_var = mobile_net(image, image_shape)
nmsed_out = fluid.layers.detection_output(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册