未验证 提交 986ff5cc 编写于 作者: Q qingqing01 提交者: GitHub

Minor change for save inferecen model. (#1436)

上级 0e3e4681
...@@ -250,6 +250,7 @@ class PyramidBox(object): ...@@ -250,6 +250,7 @@ class PyramidBox(object):
face_loc, head_loc = fluid.layers.split( face_loc, head_loc = fluid.layers.split(
mbox_loc, num_or_sections=2, dim=1) mbox_loc, num_or_sections=2, dim=1)
face_loc = permute_and_reshape(face_loc, 4) face_loc = permute_and_reshape(face_loc, 4)
if not self.is_infer:
head_loc = permute_and_reshape(head_loc, 4) head_loc = permute_and_reshape(head_loc, 4)
mbox_conf = fluid.layers.conv2d( mbox_conf = fluid.layers.conv2d(
...@@ -259,14 +260,17 @@ class PyramidBox(object): ...@@ -259,14 +260,17 @@ class PyramidBox(object):
face_conf3_maxin = fluid.layers.reduce_max( face_conf3_maxin = fluid.layers.reduce_max(
face_conf3, dim=1, keep_dim=True) face_conf3, dim=1, keep_dim=True)
face_conf = fluid.layers.concat([face_conf3_maxin, face_conf1], axis=1) face_conf = fluid.layers.concat([face_conf3_maxin, face_conf1], axis=1)
face_conf = permute_and_reshape(face_conf, 2)
if not self.is_infer:
head_conf3_maxin = fluid.layers.reduce_max( head_conf3_maxin = fluid.layers.reduce_max(
head_conf3, dim=1, keep_dim=True) head_conf3, dim=1, keep_dim=True)
head_conf = fluid.layers.concat([head_conf3_maxin, head_conf1], axis=1) head_conf = fluid.layers.concat(
face_conf = permute_and_reshape(face_conf, 2) [head_conf3_maxin, head_conf1], axis=1)
head_conf = permute_and_reshape(head_conf, 2) head_conf = permute_and_reshape(head_conf, 2)
face_locs.append(face_loc) face_locs.append(face_loc)
face_confs.append(face_conf) face_confs.append(face_conf)
if not self.is_infer:
head_locs.append(head_loc) head_locs.append(head_loc)
head_confs.append(head_conf) head_confs.append(head_conf)
...@@ -293,6 +297,7 @@ class PyramidBox(object): ...@@ -293,6 +297,7 @@ class PyramidBox(object):
face_loc, head_loc = fluid.layers.split( face_loc, head_loc = fluid.layers.split(
mbox_loc, num_or_sections=2, dim=1) mbox_loc, num_or_sections=2, dim=1)
face_loc = permute_and_reshape(face_loc, 4) face_loc = permute_and_reshape(face_loc, 4)
if not self.is_infer:
head_loc = permute_and_reshape(head_loc, 4) head_loc = permute_and_reshape(head_loc, 4)
mbox_conf = fluid.layers.conv2d(input, 6, 3, 1, 1, bias_attr=b_attr) mbox_conf = fluid.layers.conv2d(input, 6, 3, 1, 1, bias_attr=b_attr)
...@@ -304,11 +309,13 @@ class PyramidBox(object): ...@@ -304,11 +309,13 @@ class PyramidBox(object):
[face_conf1, face_conf3_maxin], axis=1) [face_conf1, face_conf3_maxin], axis=1)
face_conf = permute_and_reshape(face_conf, 2) face_conf = permute_and_reshape(face_conf, 2)
if not self.is_infer:
head_conf = permute_and_reshape(head_conf, 2) head_conf = permute_and_reshape(head_conf, 2)
face_locs.append(face_loc) face_locs.append(face_loc)
face_confs.append(face_conf) face_confs.append(face_conf)
if not self.is_infer:
head_locs.append(head_loc) head_locs.append(head_loc)
head_confs.append(head_conf) head_confs.append(head_conf)
...@@ -330,6 +337,7 @@ class PyramidBox(object): ...@@ -330,6 +337,7 @@ class PyramidBox(object):
self.face_mbox_loc = fluid.layers.concat(face_locs, axis=1) self.face_mbox_loc = fluid.layers.concat(face_locs, axis=1)
self.face_mbox_conf = fluid.layers.concat(face_confs, axis=1) self.face_mbox_conf = fluid.layers.concat(face_confs, axis=1)
if not self.is_infer:
self.head_mbox_loc = fluid.layers.concat(head_locs, axis=1) self.head_mbox_loc = fluid.layers.concat(head_locs, axis=1)
self.head_mbox_conf = fluid.layers.concat(head_confs, axis=1) self.head_mbox_conf = fluid.layers.concat(head_confs, axis=1)
......
...@@ -308,6 +308,9 @@ if __name__ == '__main__': ...@@ -308,6 +308,9 @@ if __name__ == '__main__':
infer_program, nmsed_out = network.infer(main_program) infer_program, nmsed_out = network.infer(main_program)
fetches = [nmsed_out] fetches = [nmsed_out]
fluid.io.load_persistables( fluid.io.load_persistables(
exe, args.model_dir, main_program=main_program) exe, args.model_dir, main_program=infer_program)
# save model and program
#fluid.io.save_inference_model('pyramidbox_model',
# ['image'], [nmsed_out], exe, main_program=infer_program,
# model_filename='model', params_filename='params')
infer(args, config) infer(args, config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册