提交 6ae67e35 编写于 作者: Y yeliang2258

fix

上级 3967c640
......@@ -688,28 +688,26 @@ class OpSet9():
axes = node.get_attr('axes')
if axes is None:
axes = self.graph.get_input_node(node, idx=1, copy=True)
if node.name in ["x2paddle_vis_local_cost_volume_3d_0_ExpandDims_5_0"]:
print("output_shape:", val_x.out_shapes[0])
# if len(val_x.out_shapes[0]) == 0:
# if node.name:
# self.paddle_graph.add_layer(
# 'paddle.reshape',
# inputs={"x": val_x.name},
# outputs=[node.name],
# shape=[1])
# else:
if isinstance(axes, list) or isinstance(axes, tuple):
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name},
axis=axes,
outputs=[node.name])
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
if node.name:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=[1])
else:
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name,
"axis": axes.name},
outputs=[node.name])
if isinstance(axes, list) or isinstance(axes, tuple):
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name},
axis=axes,
outputs=[node.name])
else:
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name,
"axis": axes.name},
outputs=[node.name])
@print_mapping_info
def Shrink(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册