提交 a91d37bd 编写于 作者: W wjj19950828

fixed Flatten

上级 4eb7510d
......@@ -1627,20 +1627,66 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True)
output_shape = val_x.out_shapes[0]
axis = node.get_attr('axis', 1)
shape_list = [1, 1]
if axis == 0:
for s in output_shape:
shape_list[1] *= s
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=[1, -1])
else:
for s in output_shape[:axis]:
shape_list[0] *= s
for s in output_shape[axis:]:
shape_list[1] *= s
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=shape_list)
if len(output_shape) != 0:
shape_list = [1, 1]
for s in output_shape[:axis]:
shape_list[0] *= s
for s in output_shape[axis:]:
shape_list[1] *= s
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=shape_list)
else:
self.paddle_graph.add_layer(
'paddle.shape',
inputs={"input": val_x.name},
outputs=[val_x.name + "_shape"])
self.paddle_graph.add_layer(
"paddle.slice",
inputs={"input": val_x.name + "_shape"},
outputs=[val_x.name + "_shape_first"],
axes=[0],
starts=[0],
ends=[axis])
self.paddle_graph.add_layer(
'paddle.prod',
inputs={"x": val_x.name + "_shape_first"},
outputs=[val_x.name + "_shape_first"])
self.paddle_graph.add_layer(
"paddle.slice",
inputs={"input": val_x.name + "_shape"},
outputs=[val_x.name + "_shape_second"],
axes=[0],
starts=[axis],
ends=[2147483647])
self.paddle_graph.add_layer(
'paddle.prod',
inputs={"x": val_x.name + "_shape_second"},
outputs=[val_x.name + "_shape_second"])
self.paddle_graph.add_layer(
'paddle.concat',
inputs={
"x": [
val_x.name + "_shape_first",
val_x.name + "_shape_second"
]
},
outputs=[val_x.name + "_all_shape"],
axis=0)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_x.name},
outputs=[node.name],
shape=val_x.name + "_all_shape")
@print_mapping_info
def Gemm(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册