提交 feb0f157 编写于 作者: W wjj19950828

deal with scalar tensor

上级 422a6b56
...@@ -740,26 +740,21 @@ class OpSet9(): ...@@ -740,26 +740,21 @@ class OpSet9():
def Unsqueeze(self, node): def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
if axes is not None: if axes is None:
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[ axes_node = self.graph.get_input_node(node, idx=1, copy=True)
0] == 0: axes = _const_weight_or_none(axes_node, necessary=True)
self.paddle_graph.add_layer( # deal with scalar(0D) tensor
'paddle.reshape', if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
inputs={"x": val_x.name}, self.paddle_graph.add_layer(
outputs=[node.name], 'paddle.reshape',
shape=[1]) inputs={"x": val_x.name},
else: outputs=[node.name],
self.paddle_graph.add_layer( shape=[1])
'paddle.unsqueeze',
inputs={"x": val_x.name},
axis=axes,
outputs=[node.name])
else: else:
axes = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.unsqueeze', 'paddle.unsqueeze',
inputs={"x": val_x.name, inputs={"x": val_x.name},
"axis": axes.name}, axis=axes,
outputs=[node.name]) outputs=[node.name])
@print_mapping_info @print_mapping_info
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册