提交 06a3deb5 编写于 作者: W wjj19950828

fixed nonzero

上级 005a420a
...@@ -1880,30 +1880,13 @@ class OpSet9(): ...@@ -1880,30 +1880,13 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def NonZero(self, node): def NonZero(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_x_dim = len(val_x.out_shapes[0]) self.paddle_graph.add_layer(
if val_x_dim == 1: "paddle.nonzero",
self.paddle_graph.add_layer( inputs={"x": val_x.name},
"paddle.nonzero", outputs=[val_x.name],
inputs={"x": val_x.name}, as_tuple=True)
outputs=[val_x.name]) self.paddle_graph.add_layer(
self.paddle_graph.add_layer( "paddle.concat", inputs={"x": val_x.name}, outputs=[node.name])
"paddle.transpose",
inputs={"x": val_x.name},
outputs=[node.layer_name],
perm=[1, 0])
if val_x_dim > 1:
self.paddle_graph.add_layer(
"paddle.nonzero",
inputs={"x": val_x.name},
outputs=[val_x.name])
self.paddle_graph.add_layer(
"paddle.split",
inputs={"x": val_x.name},
outputs=[val_x.name],
num_or_sections=1,
axis=val_x_dim)
self.paddle_graph.add_layer(
"paddle.concat", inputs={"x": val_x.name}, outputs=[node.name])
@print_mapping_info @print_mapping_info
def Identity(self, node): def Identity(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册