提交 dfb5a46e 编写于 作者: S SunAhong1993

fix

上级 60cd5b3b
...@@ -298,6 +298,8 @@ class OpSet9(): ...@@ -298,6 +298,8 @@ class OpSet9():
# TODO(syf): all use # TODO(syf): all use
inputs['out_shape'] = var_hw inputs['out_shape'] = var_hw
ipt = inputs.pop("x")
inputs["input"] = ipt
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False} attrs = {"align_corners": False}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
...@@ -454,19 +456,6 @@ class OpSet9(): ...@@ -454,19 +456,6 @@ class OpSet9():
inputs={"x": self.get_node_name(val_x)}, inputs={"x": self.get_node_name(val_x)},
outputs=[node.layer_name], outputs=[node.layer_name],
shape=[1]) shape=[1])
else:
if str(val_x.dtype) == 'bool':
val_x_cast = val_x.layer_name + '_cast'
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": self.get_node_name(val_x)},
outputs=[val_x_cast],
dtype=string('int64'))
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x_cast},
outputs=[node.layer_name],
**layer_attrs)
else: else:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.unsqueeze', 'paddle.unsqueeze',
...@@ -1013,25 +1002,6 @@ class OpSet9(): ...@@ -1013,25 +1002,6 @@ class OpSet9():
inputs={'x': self.get_node_name(val_x)}, inputs={'x': self.get_node_name(val_x)},
outputs=[node.layer_name], outputs=[node.layer_name],
shape=node.out_shapes[0]) shape=node.out_shapes[0])
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
self.paddle_graph.add_layer(
'paddle.cast',
inputs={'x': self.get_node_name(val_shape)},
outputs=[val_shape_cast],
dtype=string('int32'))
# shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': self.get_node_name(val_shape_cast)},
outputs=[val_shape_cast],
shape=val_shape.out_shapes[0])
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': elf.get_node_name(val_x),
'shape': val_shape_cast},
outputs=[node.layer_name])
else: else:
# shape may be [], come form Gather by scalar indices # shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0: if len(val_shape.out_shapes[0]) > 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册