提交 3a1bad1c 编写于 作者: D dangqingqing

Refine train.py for MobileNet-SSD.

上级 356d9637
......@@ -12,8 +12,9 @@ import functools
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
# yapf: disable
......@@ -47,26 +48,23 @@ def train(args,
locs, confs, box, box_var = mobile_net(image_, image_shape)
loss = fluid.layers.ssd_loss(locs, confs, gt_box_, gt_label_,
box, box_var)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
pd.write_output(loss)
pd.write_output(locs)
pd.write_output(confs)
pd.write_output(box)
pd.write_output(box_var)
pd.write_output(nmsed_out)
loss, locs, confs, box, box_var = pd()
loss, nmsed_out = pd()
loss = fluid.layers.reduce_sum(loss)
else:
locs, confs, box, box_var = mobile_net(image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, mbox_confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.ssd_loss(locs, mbox_confs, gt_box, gt_label,
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label,
box, box_var)
loss = fluid.layers.reduce_sum(loss)
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
......@@ -100,7 +98,6 @@ def train(args,
feeder = fluid.DataFeeder(
place=place, feed_list=[image, gt_box, gt_label, difficult])
#print 'test_program ', test_program
def test(pass_id):
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
......@@ -111,7 +108,6 @@ def train(args,
fetch_list=[accum_map])
print("Test {0}, map {1}".format(pass_id, test_map[0]))
#print 'main_program ', fluid.default_main_program()
for pass_id in range(num_passes):
for batch_id, data in enumerate(train_reader()):
loss_v = exe.run(fluid.default_main_program(),
......@@ -142,5 +138,5 @@ if __name__ == '__main__':
val_file_list='./data/test.txt',
data_args=data_args,
learning_rate=0.001,
batch_size=32,
batch_size=args.batch_size,
num_passes=300)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册