提交 dfb5a46e 编写于 作者: S SunAhong1993

fix

上级 60cd5b3b
......@@ -298,6 +298,8 @@ class OpSet9():
# TODO(syf): all use
inputs['out_shape'] = var_hw
ipt = inputs.pop("x")
inputs["input"] = ipt
mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False}
self.paddle_graph.add_layer(
......@@ -455,24 +457,11 @@ class OpSet9():
outputs=[node.layer_name],
shape=[1])
else:
if str(val_x.dtype) == 'bool':
val_x_cast = val_x.layer_name + '_cast'
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": self.get_node_name(val_x)},
outputs=[val_x_cast],
dtype=string('int64'))
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x_cast},
outputs=[node.layer_name],
**layer_attrs)
else:
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": self.get_node_name(val_x)},
outputs=[node.layer_name],
**layer_attrs)
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": self.get_node_name(val_x)},
outputs=[node.layer_name],
**layer_attrs)
@print_mapping_info
def Shrink(self, node):
......@@ -1013,25 +1002,6 @@ class OpSet9():
inputs={'x': self.get_node_name(val_x)},
outputs=[node.layer_name],
shape=node.out_shapes[0])
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
self.paddle_graph.add_layer(
'paddle.cast',
inputs={'x': self.get_node_name(val_shape)},
outputs=[val_shape_cast],
dtype=string('int32'))
# shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': self.get_node_name(val_shape_cast)},
outputs=[val_shape_cast],
shape=val_shape.out_shapes[0])
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': elf.get_node_name(val_x),
'shape': val_shape_cast},
outputs=[node.layer_name])
else:
# shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册