提交 6ae67e35 编写于 作者: Y yeliang2258

fix

上级 3967c640
...@@ -688,28 +688,26 @@ class OpSet9(): ...@@ -688,28 +688,26 @@ class OpSet9():
axes = node.get_attr('axes') axes = node.get_attr('axes')
if axes is None: if axes is None:
axes = self.graph.get_input_node(node, idx=1, copy=True) axes = self.graph.get_input_node(node, idx=1, copy=True)
if node.name in ["x2paddle_vis_local_cost_volume_3d_0_ExpandDims_5_0"]: if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
print("output_shape:", val_x.out_shapes[0]) if node.name:
# if len(val_x.out_shapes[0]) == 0: self.paddle_graph.add_layer(
# if node.name: 'paddle.reshape',
# self.paddle_graph.add_layer( inputs={"x": val_x.name},
# 'paddle.reshape', outputs=[node.name],
# inputs={"x": val_x.name}, shape=[1])
# outputs=[node.name],
# shape=[1])
# else:
if isinstance(axes, list) or isinstance(axes, tuple):
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name},
axis=axes,
outputs=[node.name])
else: else:
self.paddle_graph.add_layer( if isinstance(axes, list) or isinstance(axes, tuple):
'paddle.unsqueeze', self.paddle_graph.add_layer(
inputs={"x": val_x.name, 'paddle.unsqueeze',
"axis": axes.name}, inputs={"x": val_x.name},
outputs=[node.name]) axis=axes,
outputs=[node.name])
else:
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name,
"axis": axes.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def Shrink(self, node): def Shrink(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册