提交 6511e005 编写于 作者: S SunAhong1993

remove transpose

上级 4d7134cc
...@@ -96,6 +96,7 @@ class TFOpMapper(OpMapper): ...@@ -96,6 +96,7 @@ class TFOpMapper(OpMapper):
raise Exception("Model is not supported yet.") raise Exception("Model is not supported yet.")
self.params = dict() self.params = dict()
self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="static", source_type="tf") self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="static", source_type="tf")
self.params_output2id = dict()
not_placeholder = list() not_placeholder = list()
for name in self.graph.input_nodes: for name in self.graph.input_nodes:
...@@ -224,7 +225,7 @@ class TFOpMapper(OpMapper): ...@@ -224,7 +225,7 @@ class TFOpMapper(OpMapper):
return return
self.params[node.name] = node.value self.params[node.name] = node.value
self.paddle_graph.add_layer( layer_id = self.paddle_graph.add_layer(
kernel="paddle.static.create_parameter", kernel="paddle.static.create_parameter",
inputs={}, inputs={},
outputs=[node.name], outputs=[node.name],
...@@ -232,6 +233,7 @@ class TFOpMapper(OpMapper): ...@@ -232,6 +233,7 @@ class TFOpMapper(OpMapper):
shape=shape, shape=shape,
name=string(node.name), name=string(node.name),
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)")
self.params_output2id[node.name] = layer_id
def Transpose(self, node): def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0]) input = self.graph.get_node(node.layer.input[0])
...@@ -774,11 +776,17 @@ class TFOpMapper(OpMapper): ...@@ -774,11 +776,17 @@ class TFOpMapper(OpMapper):
data_format = node.get_attr("data_format").decode() data_format = node.get_attr("data_format").decode()
pad_mode = node.get_attr("padding").decode() pad_mode = node.get_attr("padding").decode()
self.paddle_graph.add_layer( if len(kernel.outputs) == 1:
kernel="paddle.transpose", self.params[kernel.name] = numpy.transpose(self.params[kernel.name],
inputs={"x": kernel.name}, (2, 3, 0, 1))
outputs=[kernel.name], layer = self.paddle_graph.layers[self.params_output2id[kernel.name]]
perm=[2, 3, 0, 1]) layer.attrs["shape"] = self.params[kernel.name].shape
else:
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": kernel.name},
outputs=[kernel.name],
perm=[2, 3, 0, 1])
input_name = input.name input_name = input.name
if data_format == "NHWC": if data_format == "NHWC":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册