提交 a91d37bd 编写于 作者: W wjj19950828

fixed Flatten

上级 4eb7510d
...@@ -1627,20 +1627,66 @@ class OpSet9(): ...@@ -1627,20 +1627,66 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
output_shape = val_x.out_shapes[0] output_shape = val_x.out_shapes[0]
axis = node.get_attr('axis', 1) axis = node.get_attr('axis', 1)
shape_list = [1, 1]
if axis == 0: if axis == 0:
for s in output_shape: self.paddle_graph.add_layer(
shape_list[1] *= s 'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=[1, -1])
else: else:
for s in output_shape[:axis]: if len(output_shape) != 0:
shape_list[0] *= s shape_list = [1, 1]
for s in output_shape[axis:]: for s in output_shape[:axis]:
shape_list[1] *= s shape_list[0] *= s
self.paddle_graph.add_layer( for s in output_shape[axis:]:
'paddle.reshape', shape_list[1] *= s
inputs={"x": val_x.name}, self.paddle_graph.add_layer(
outputs=[node.name], 'paddle.reshape',
shape=shape_list) inputs={"x": val_x.name},
outputs=[node.name],
shape=shape_list)
else:
self.paddle_graph.add_layer(
'paddle.shape',
inputs={"input": val_x.name},
outputs=[val_x.name + "_shape"])
self.paddle_graph.add_layer(
"paddle.slice",
inputs={"input": val_x.name + "_shape"},
outputs=[val_x.name + "_shape_first"],
axes=[0],
starts=[0],
ends=[axis])
self.paddle_graph.add_layer(
'paddle.prod',
inputs={"x": val_x.name + "_shape_first"},
outputs=[val_x.name + "_shape_first"])
self.paddle_graph.add_layer(
"paddle.slice",
inputs={"input": val_x.name + "_shape"},
outputs=[val_x.name + "_shape_second"],
axes=[0],
starts=[axis],
ends=[2147483647])
self.paddle_graph.add_layer(
'paddle.prod',
inputs={"x": val_x.name + "_shape_second"},
outputs=[val_x.name + "_shape_second"])
self.paddle_graph.add_layer(
'paddle.concat',
inputs={
"x": [
val_x.name + "_shape_first",
val_x.name + "_shape_second"
]
},
outputs=[val_x.name + "_all_shape"],
axis=0)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_x.name},
outputs=[node.name],
shape=val_x.name + "_all_shape")
@print_mapping_info @print_mapping_info
def Gemm(self, node): def Gemm(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册