提交 3fbf8b4c 编写于 作者: W wuzewu

Fixed the bug where multiple call contexts caused parameter names to alternate in retinanet module

上级 abc05deb
...@@ -69,7 +69,7 @@ class RetinaNetResNet50FPN(hub.Module): ...@@ -69,7 +69,7 @@ class RetinaNetResNet50FPN(hub.Module):
num_classes=81, num_classes=81,
trainable=True, trainable=True,
pretrained=True, pretrained=True,
get_prediction=False): phase='train'):
""" """
Distill the Head Features, so as to perform transfer learning. Distill the Head Features, so as to perform transfer learning.
...@@ -77,7 +77,7 @@ class RetinaNetResNet50FPN(hub.Module): ...@@ -77,7 +77,7 @@ class RetinaNetResNet50FPN(hub.Module):
num_classes (int): number of classes. num_classes (int): number of classes.
trainable (bool): whether to set parameters trainable. trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model. pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction. phase (str): optional choices are 'train' and 'predict'.
Returns: Returns:
inputs(dict): the input variables. inputs(dict): the input variables.
...@@ -87,6 +87,7 @@ class RetinaNetResNet50FPN(hub.Module): ...@@ -87,6 +87,7 @@ class RetinaNetResNet50FPN(hub.Module):
context_prog = fluid.Program() context_prog = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
with fluid.program_guard(context_prog, startup_program): with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
var_prefix = '@HUB_{}@'.format(self.name) var_prefix = '@HUB_{}@'.format(self.name)
# image # image
image = fluid.layers.data( image = fluid.layers.data(
...@@ -139,14 +140,16 @@ class RetinaNetResNet50FPN(hub.Module): ...@@ -139,14 +140,16 @@ class RetinaNetResNet50FPN(hub.Module):
'image': var_prefix + image.name, 'image': var_prefix + image.name,
'im_info': var_prefix + im_info.name 'im_info': var_prefix + im_info.name
} }
if get_prediction: if phase == 'predict':
pred = retina_head.get_prediction(body_feats, spatial_scale, pred = retina_head.get_prediction(body_feats, spatial_scale,
im_info) im_info)
outputs = {'bbox_out': var_prefix + pred.name} outputs = {'bbox_out': var_prefix + pred.name}
else: else:
outputs = { outputs = {
'body_features': 'body_features': [
[var_prefix + var.name for key, var in body_feats.items()] var_prefix + var.name
for key, var in body_feats.items()
]
} }
# add_vars_prefix # add_vars_prefix
...@@ -154,7 +157,10 @@ class RetinaNetResNet50FPN(hub.Module): ...@@ -154,7 +157,10 @@ class RetinaNetResNet50FPN(hub.Module):
add_vars_prefix(fluid.default_startup_program(), var_prefix) add_vars_prefix(fluid.default_startup_program(), var_prefix)
global_vars = context_prog.global_block().vars global_vars = context_prog.global_block().vars
inputs = {key: global_vars[value] for key, value in inputs.items()} inputs = {
key: global_vars[value]
for key, value in inputs.items()
}
outputs = { outputs = {
key: global_vars[value] if not isinstance(value, list) else key: global_vars[value] if not isinstance(value, list) else
[global_vars[var] for var in value] [global_vars[var] for var in value]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册