提交 ede3ecc4 编写于 作者: W wuzewu

Fix fpn issue in paddle 1.7.x

上级 226d8ca9
......@@ -89,7 +89,7 @@ class FasterRCNNResNet50(hub.Module):
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=[3, 800, 1333], dtype='float32')
name='image', shape=[-1, 3, -1, -1], dtype='float32')
# backbone
backbone = ResNet(
norm_type='affine_channel',
......@@ -201,9 +201,13 @@ class FasterRCNNResNet50(hub.Module):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
if pretrained:
def _if_exist(var):
if num_classes != 81:
if 'bbox_pred' in var.name or 'cls_score' in var.name:
return False
return os.path.exists(
os.path.join(self.default_pretrained_model_path,
var.name))
......@@ -212,8 +216,6 @@ class FasterRCNNResNet50(hub.Module):
exe,
self.default_pretrained_model_path,
predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog
def rpn_head(self):
......
......@@ -90,7 +90,7 @@ class FasterRCNNResNet50RPN(hub.Module):
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=[3, 800, 1333], dtype='float32')
name='image', shape=[-1, 3, -1, -1], dtype='float32')
# backbone
backbone = ResNet(
norm_type='affine_channel',
......@@ -202,9 +202,13 @@ class FasterRCNNResNet50RPN(hub.Module):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
if pretrained:
def _if_exist(var):
if num_classes != 81:
if 'bbox_pred' in var.name or 'cls_score' in var.name:
return False
return os.path.exists(
os.path.join(self.default_pretrained_model_path,
var.name))
......@@ -213,8 +217,6 @@ class FasterRCNNResNet50RPN(hub.Module):
exe,
self.default_pretrained_model_path,
predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog
def rpn_head(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册