提交 ede3ecc4 编写于 作者: W wuzewu

Fix fpn issue in paddle 1.7.x

上级 226d8ca9
...@@ -89,7 +89,7 @@ class FasterRCNNResNet50(hub.Module): ...@@ -89,7 +89,7 @@ class FasterRCNNResNet50(hub.Module):
with fluid.program_guard(context_prog, startup_program): with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
image = fluid.layers.data( image = fluid.layers.data(
name='image', shape=[3, 800, 1333], dtype='float32') name='image', shape=[-1, 3, -1, -1], dtype='float32')
# backbone # backbone
backbone = ResNet( backbone = ResNet(
norm_type='affine_channel', norm_type='affine_channel',
...@@ -201,9 +201,13 @@ class FasterRCNNResNet50(hub.Module): ...@@ -201,9 +201,13 @@ class FasterRCNNResNet50(hub.Module):
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_program)
if pretrained: if pretrained:
def _if_exist(var): def _if_exist(var):
if num_classes != 81:
if 'bbox_pred' in var.name or 'cls_score' in var.name:
return False
return os.path.exists( return os.path.exists(
os.path.join(self.default_pretrained_model_path, os.path.join(self.default_pretrained_model_path,
var.name)) var.name))
...@@ -212,8 +216,6 @@ class FasterRCNNResNet50(hub.Module): ...@@ -212,8 +216,6 @@ class FasterRCNNResNet50(hub.Module):
exe, exe,
self.default_pretrained_model_path, self.default_pretrained_model_path,
predicate=_if_exist) predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog return inputs, outputs, context_prog
def rpn_head(self): def rpn_head(self):
......
...@@ -90,7 +90,7 @@ class FasterRCNNResNet50RPN(hub.Module): ...@@ -90,7 +90,7 @@ class FasterRCNNResNet50RPN(hub.Module):
with fluid.program_guard(context_prog, startup_program): with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
image = fluid.layers.data( image = fluid.layers.data(
name='image', shape=[3, 800, 1333], dtype='float32') name='image', shape=[-1, 3, -1, -1], dtype='float32')
# backbone # backbone
backbone = ResNet( backbone = ResNet(
norm_type='affine_channel', norm_type='affine_channel',
...@@ -202,9 +202,13 @@ class FasterRCNNResNet50RPN(hub.Module): ...@@ -202,9 +202,13 @@ class FasterRCNNResNet50RPN(hub.Module):
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_program)
if pretrained: if pretrained:
def _if_exist(var): def _if_exist(var):
if num_classes != 81:
if 'bbox_pred' in var.name or 'cls_score' in var.name:
return False
return os.path.exists( return os.path.exists(
os.path.join(self.default_pretrained_model_path, os.path.join(self.default_pretrained_model_path,
var.name)) var.name))
...@@ -213,8 +217,6 @@ class FasterRCNNResNet50RPN(hub.Module): ...@@ -213,8 +217,6 @@ class FasterRCNNResNet50RPN(hub.Module):
exe, exe,
self.default_pretrained_model_path, self.default_pretrained_model_path,
predicate=_if_exist) predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog return inputs, outputs, context_prog
def rpn_head(self): def rpn_head(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册