提交 92ecc5a9 编写于 作者: W wjj19950828

deal with comments

上级 d7621aea
...@@ -1651,47 +1651,18 @@ class OpSet9(): ...@@ -1651,47 +1651,18 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
shape=shape_list) shape=shape_list)
else: else:
# flatten + reshape
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.shape', "paddle.flatten",
inputs={"input": val_x.name}, inputs={"input": val_x.name},
outputs=[val_x.name + "_shape"]) outputs=[val_x.name + "_flatten"],
self.paddle_graph.add_layer( start_axis=[0],
"paddle.slice", stop_axis=[axis])
inputs={"input": val_x.name + "_shape"},
outputs=[val_x.name + "_shape_first"],
axes=[0],
starts=[0],
ends=[axis])
self.paddle_graph.add_layer(
'paddle.prod',
inputs={"x": val_x.name + "_shape_first"},
outputs=[val_x.name + "_shape_first"])
self.paddle_graph.add_layer(
"paddle.slice",
inputs={"input": val_x.name + "_shape"},
outputs=[val_x.name + "_shape_second"],
axes=[0],
starts=[axis],
ends=[2147483647])
self.paddle_graph.add_layer(
'paddle.prod',
inputs={"x": val_x.name + "_shape_second"},
outputs=[val_x.name + "_shape_second"])
self.paddle_graph.add_layer(
'paddle.concat',
inputs={
"x": [
val_x.name + "_shape_first",
val_x.name + "_shape_second"
]
},
outputs=[val_x.name + "_all_shape"],
axis=0)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', 'paddle.reshape',
inputs={'x': val_x.name}, inputs={'x': val_x.name + "_flatten"},
outputs=[node.name], outputs=[node.name],
shape=val_x.name + "_all_shape") shape=[0, -1])
@print_mapping_info @print_mapping_info
def Gemm(self, node): def Gemm(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册