Unverified · Commit e2b3e7e0 · authored by Jason, committed by GitHub

Merge pull request #536 from SunAhong1993/develop

add PyTorch op
@@ -76,6 +76,7 @@ class PaddleGraph(object):
         self.source_type = source_type
         self.custom_code = None
         self.inputs_info = None
+        self.has_unpack = False

     def set_name(self, name):
         self.name = name.replace("-", "_").replace("/", "_")
@@ -112,6 +113,8 @@ class PaddleGraph(object):
                 layer_id)
         layer = PaddleLayer(layer_id, kernel, inputs, outputs, scope_name=scope_name, **kwargs)
         self.layers[layer_id] = layer
+        if layer.kernel in ["prim.list_unpack", "prim.tuple_unpack"]:
+            self.has_unpack = True
         return layer_id

     def del_layer(self, layer_id):
@@ -272,12 +275,16 @@ class PaddleGraph(object):
     def gen_dygraph_model(self, save_dir, jit_type=None):
         if jit_type == "trace":
-            from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
-            hierarchical_tree = HierarchicalTree(self)
-            for layer_id, layer in self.layers.items():
-                hierarchical_tree.insert(layer)
-            hierarchical_tree.save_source_files(save_dir)
-            self.dump_dygraph_parameter(save_dir)
+            if not self.has_unpack:
+                from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
+                hierarchical_tree = HierarchicalTree(self)
+                for layer_id, layer in self.layers.items():
+                    hierarchical_tree.insert(layer)
+                hierarchical_tree.save_source_files(save_dir)
+                self.dump_dygraph_parameter(save_dir)
+            else:
+                self.gen_dygraph_code(save_dir)
+                self.dump_dygraph_parameter(save_dir)
         else:
             if self.source_type == "pytorch":
                 from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
......
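For orientation, a minimal standalone sketch of the new trace-path dispatch; the helper name is made up and the reasoning is inferred from this diff alone:

    # Sketch: how gen_dygraph_model now dispatches in trace mode.
    def gen_trace_model(graph, save_dir):
        if not graph.has_unpack:
            # No prim.list_unpack / prim.tuple_unpack layers were added,
            # so the module hierarchy can be rebuilt as nested nn.Layers.
            from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
            tree = HierarchicalTree(graph)
            for _, layer in graph.layers.items():
                tree.insert(layer)
            tree.save_source_files(save_dir)
        else:
            # Unpack layers are not handled by the hierarchical tree,
            # so fall back to flat dygraph code generation.
            graph.gen_dygraph_code(save_dir)
        graph.dump_dygraph_parameter(save_dir)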
@@ -64,5 +64,5 @@ class TraceDecoder(Decoder):
             print(e)
             exit(0)
         self.graph = self._optimize_graph(self.script.inlined_graph)
         self.input_examples = input_examples
@@ -101,6 +101,7 @@ class TFGraphNode(GraphNode):
     @property
     def name(self):
         if hasattr(self, 'index'):
+            print(self.layer_type)
             return self.layer_name + "_p{}".format(self.index)
         return self.layer_name
@@ -184,7 +185,7 @@ class TFGraph(Graph):
         node = super(TFGraph, self).get_node(new_node_name, copy)
         if node is None:
             return None
-        if node.layer_type == "Switch":
+        if node.layer_type in ["Switch", "Reshape", "Sub"]:
             if hasattr(node, 'index'):
                 del node.index
         if len(items) == 1 and node.layer_type in self.multi_out_ops:
@@ -284,6 +285,11 @@ class TFGraph(Graph):
         if node_name in self.output_nodes:
             idx = self.output_nodes.index(node_name)
-            self.output_nodes[idx] = input_node.layer_name
+            if len(input_node.outputs) > 0:
+                self.output_nodes.pop(idx)
+            else:
+                self.output_nodes[idx] = input_node.layer_name

     def _remove_cast_node(self):
         cast_node = list()
......
@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node):
     tensor_value = value
     value = "{}".format(value)
     if "tensor" in value:
-        if isinstance(tensor_value, list) or isinstance(tensor_value, tuple):
+        if isinstance(tensor_value, list) or isinstance(tensor_value,
+                                                        tuple):
             name_dict = dict()
             for i, tv in enumerate(tensor_value):
-                output_name_i = "{}_p{}".format(output_name,i)
+                output_name_i = "{}_p{}".format(output_name, i)
                 key_i = "input{}".format(i)
-                mapper.paddle_params[output_name_i] = tv.cpu().detach().numpy()
+                mapper.paddle_params[output_name_i] = tv.cpu().detach(
+                ).numpy()
                 graph.add_layer(
                     "self.create_parameter",
                     inputs={},
                     outputs=[output_name_i],
                     scope_name=scope_name,
-                    dtype=string(str(mapper.paddle_params[output_name_i].dtype)),
-                    shape = mapper.paddle_params[output_name_i].shape,
-                    default_initializer="paddle.nn.initializer.Constant(value=0.0)")
+                    dtype=string(
+                        str(mapper.paddle_params[output_name_i].dtype)),
+                    shape=mapper.paddle_params[output_name_i].shape,
+                    default_initializer="paddle.nn.initializer.Constant(value=0.0)"
+                )
                 name_dict[key_i] = output_name_i
             graph.add_layer(
                 "prim.list",
@@ -59,8 +63,19 @@ def prim_Constant(mapper, graph, node):
                 scope_name=scope_name)
             return [], [output_name]
         else:
-            mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
+            # mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
+            mapper.paddle_params[output_name] = tensor_value.cpu().detach(
+            ).numpy()
+            graph.add_layer(
+                "self.create_parameter",
+                inputs={},
+                outputs=[output_name],
+                scope_name=scope_name,
+                dtype=string(str(mapper.paddle_params[output_name].dtype)),
+                shape=mapper.paddle_params[output_name].shape,
+                default_initializer="paddle.nn.initializer.Constant(value=0.0)"
+            )
+            return [], [output_name]
     if "inf" in str(value):
         t = str(type(value)).split("'")[1]
         if str(value).startswith("-"):
@@ -72,7 +87,11 @@ def prim_Constant(mapper, graph, node):
             value = int(math.pow(2, 31) - 1)
     mapper.attrs[output_name] = value
     graph.add_layer(
-        "prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=value)
+        "prim.constant",
+        inputs={},
+        outputs=[output_name],
+        scope_name=scope_name,
+        value=value)
     return [], [output_name]
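As a worked illustration of the new else branch, the emitted model code for a single constant tensor unrolls roughly as below; the attribute name and shape are made up:

    import paddle

    class _ConstDemo(paddle.nn.Layer):
        def __init__(self):
            super(_ConstDemo, self).__init__()
            # What a "self.create_parameter" layer expands to; the zero
            # initializer is only a placeholder, the real values are
            # restored from the dumped parameter file at load time.
            self.x2336 = self.create_parameter(
                dtype='float32',
                shape=[3, 4],
                default_initializer=paddle.nn.initializer.Constant(value=0.0))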
@@ -96,18 +115,23 @@ def prim_data(mapper, graph, node):
     # Get the list of the current node's outputs
     current_outputs = [output_name]
     # Handle input 0, i.e. %4336
-    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
     layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
-    graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+    graph.add_layer(
+        "prim.equal",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     return current_inputs, current_outputs

def prim_DictConstruct(mapper, graph, node):
    """ Construct a dict.
    TorchScript example:
        %32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
    Parameter meanings:
@@ -127,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node):
     current_outputs = [output_name]
     # Handle each input
     for i, input_name in enumerate(inputs_name):
-        if i%2 == 0:
-            layer_attrs["key{}".format(int(i/2))] = mapper.attrs[input_name]
+        if i % 2 == 0:
+            layer_attrs["key{}".format(int(i / 2))] = mapper.attrs[input_name]
         else:
-            layer_inputs["value{}".format(int(i/2))] = input_name
+            layer_inputs["value{}".format(int(i / 2))] = input_name
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
-    graph.add_layer("prim.dict_construct",
-                    inputs=layer_inputs,
-                    outputs=layer_outputs,
-                    scope_name=scope_name,
-                    **layer_attrs)
+    graph.add_layer(
+        "prim.dict_construct",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name,
+        **layer_attrs)
     return current_inputs, current_outputs
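To make the even/odd split concrete, here is the interleaving for the docstring example %32 = prim::DictConstruct(%30, %23, %31, %29), with placeholder names and key values:

    # prim::DictConstruct inputs alternate key, value, key, value, ...
    inputs_name = ["x30", "x23", "x31", "x29"]   # placeholder node names
    attrs = {"x30": "boxes", "x31": "scores"}    # keys known at convert time
    layer_attrs, layer_inputs = {}, {}
    for i, input_name in enumerate(inputs_name):
        if i % 2 == 0:
            layer_attrs["key{}".format(i // 2)] = attrs[input_name]
        else:
            layer_inputs["value{}".format(i // 2)] = input_name
    # layer_attrs  == {"key0": "boxes", "key1": "scores"}
    # layer_inputs == {"value0": "x23", "value1": "x29"}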

def prim_GetAttr(mapper, graph, node):
    """ Get attribute information.
@@ -203,8 +227,13 @@ def prim_If(mapper, graph, node):
     input_node = list(node.inputs())[0].node()
     script_input_unique_id = list(node.inputs())[0].unique()
     input_node_name = mapper.outputs_info[script_input_unique_id]
-    mapper._check_input(graph, input_node, input_node_name, current_outputs, scope_name)
-    graph.add_layer("prim.if", inputs={'input': input_node_name}, outputs=node_outputs, scope_name=scope_name)
+    mapper._check_input(graph, input_node, input_node_name, current_outputs,
+                        scope_name)
+    graph.add_layer(
+        "prim.if",
+        inputs={'input': input_node_name},
+        outputs=node_outputs,
+        scope_name=scope_name)
     current_layer = list(graph.layers.values())[-1]
     block0 = list(node.blocks())[0]
     block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer)
@@ -240,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node):
     current_outputs = [output_name]
     # Handle each input
     for i, input_name in enumerate(inputs_name):
-        mapper._check_input(graph, inputs_node[i], input_name, current_outputs, scope_name)
+        mapper._check_input(graph, inputs_node[i], input_name, current_outputs,
+                            scope_name)
         layer_inputs["input{}".format(i)] = input_name
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
-    layer_id = graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+    layer_id = graph.add_layer(
+        "prim.list",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     mapper.output2id[output_name] = layer_id
     return current_inputs, current_outputs
@@ -268,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node):
     # Get the list of the current node's outputs
     current_outputs = layer_outputs.copy()
     # Handle input 0, i.e. %4354
-    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
     layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
     graph.add_layer(
-        "prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+        "prim.list_unpack",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs)
     return current_inputs, current_outputs
@@ -333,7 +371,11 @@ def prim_Loop(mapper, graph, node):
             scope_name=scope_name)
         node_outputs.append(block_input_node_name)
-    graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs, scope_name=scope_name)
+    graph.add_layer(
+        "prim.loop",
+        inputs=loop_inputs,
+        outputs=loop_outputs,
+        scope_name=scope_name)
     current_layer = list(graph.layers.values())[-1]
     block_graph, graph_inputs = mapper.traverse(block, current_layer)
     for i, input_name in enumerate(graph_inputs):
@@ -361,12 +403,17 @@ def prim_min(mapper, graph, node):
     # Get the list of the current node's outputs
     current_outputs = [output_name]
     # Handle input 0, i.e. %86
-    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
     layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
-    graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+    graph.add_layer(
+        "prim.min",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     return current_inputs, current_outputs
@@ -388,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node):
     # Get the list of the current node's outputs
     current_outputs = [output_name]
     # Handle input 0, i.e. %86
-    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
-    inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(inputs_node[0])
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
+    inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(
+        inputs_node[0])
     if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1:
         layer_inputs["input"] = inputs_name[0]
         # Get the list of the current node's inputs
         current_inputs = list(layer_inputs.values())
         graph.add_layer(
-            "prim_equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+            "prim_equal",
+            inputs=layer_inputs,
+            outputs=layer_outputs,
+            scope_name=scope_name)
     else:
         layer_inputs["fill_value"] = inputs_name[0]
         # Get the list of the current node's inputs
@@ -428,13 +480,17 @@ def prim_RaiseException(mapper, graph, node):
     # Get the list of the current node's outputs
     current_outputs = [output_name]
     # Handle input 0, i.e. %76
-    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
     layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
     graph.add_layer(
-        "prim.exception", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+        "prim.exception",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     return current_inputs, current_outputs
@@ -455,13 +511,17 @@ def prim_requires_grad(mapper, graph, node):
     # Get the list of the current node's outputs
     current_outputs = [output_name]
     # Handle input 0, i.e. %86
-    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
     layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
     graph.add_layer(
-        "prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+        "prim.requires_grad",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     return current_inputs, current_outputs
@@ -518,13 +578,17 @@ def prim_shape(mapper, graph, node):
     # Get the list of the current node's outputs
     current_outputs = [output_name]
     # Handle input 0, i.e. %input.8
-    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
     layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
     graph.add_layer(
-        "paddle.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+        "paddle.shape",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     return current_inputs, current_outputs
@@ -551,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node):
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
-    graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+    graph.add_layer(
+        "prim.tuple",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     return current_inputs, current_outputs
@@ -569,15 +637,23 @@ def prim_TupleUnpack(mapper, graph, node):
     outputs_name = mapper._get_outputs_name(node)
     layer_outputs = outputs_name
     layer_inputs = {}
+    layer_attrs = {}
     inputs_name, inputs_node = mapper._get_inputs_name(node)
+    if inputs_node[0].kind() == "prim::GetAttr":
+        layer_attrs["input"] = list(mapper.pytorch_params[inputs_name[0]])
+    else:
+        layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's outputs
     current_outputs = outputs_name
-    layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
     graph.add_layer(
-        "prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+        "prim.tuple_unpack",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name,
+        **layer_attrs)
     return current_inputs, current_outputs
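The new branch handles tuples stored as module attributes (reached through prim::GetAttr), whose elements are already known at conversion time; a schematic with made-up names:

    # If the unpacked tuple is a module attribute, its elements become a
    # layer attribute and the prim.tuple_unpack layer gets no graph input.
    pytorch_params = {"rnn_state": (0, 1)}   # stand-in for mapper.pytorch_params
    inputs_name = ["rnn_state"]
    input_kind = "prim::GetAttr"             # inputs_node[0].kind()
    layer_inputs, layer_attrs = {}, {}
    if input_kind == "prim::GetAttr":
        layer_attrs["input"] = list(pytorch_params[inputs_name[0]])
    else:
        layer_inputs["input"] = inputs_name[0]
    # layer_attrs == {"input": [0, 1]}, layer_inputs == {}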
@@ -601,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node):
     # Get the list of the current node's outputs
     current_outputs = [output_name]
     # Handle input 0, i.e. %size.63
-    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
     layer_inputs["input"] = inputs_name[0]
     # Get the list of the current node's inputs
     current_inputs = list(layer_inputs.values())
-    graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+    graph.add_layer(
+        "prim.equal",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name)
     return current_inputs, current_outputs
@@ -623,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node):
     output = list(node.outputs())[0]
     mapper.attrs[output_name] = None
     graph.add_layer(
-        "prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=None)
+        "prim.constant",
+        inputs={},
+        outputs=[output_name],
+        scope_name=scope_name,
+        value=None)
     return [], [output_name]
@@ -13,4 +13,5 @@
 # limitations under the License.

-from .gather import Gather
\ No newline at end of file
+from .gather import Gather
+from .instance_norm import InstanceNorm
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.nn.functional import instance_norm
from paddle.fluid.initializer import Constant


class InstanceNorm(paddle.nn.Layer):
    """
    This class is the base class for InstanceNorm1D, InstanceNorm2D and InstanceNorm3D.
    See InstanceNorm1D, InstanceNorm2D or InstanceNorm3D for more details.
    """

def __init__(self,
num_features,
epsilon=1e-5,
momentum=0.9,
weight_attr=None,
bias_attr=None,
data_format="NCHW",
name=None):
super(InstanceNorm, self).__init__()
if weight_attr == False or bias_attr == False:
            assert weight_attr == bias_attr, "weight_attr and bias_attr must be set to False at the same time in InstanceNorm"
self._epsilon = epsilon
self._weight_attr = weight_attr
self._bias_attr = bias_attr
if weight_attr != False and bias_attr != False:
self.scale = self.create_parameter(
attr=self._weight_attr,
shape=[num_features],
default_initializer=Constant(1.0),
is_bias=False)
self.bias = self.create_parameter(
attr=self._bias_attr,
shape=[num_features],
default_initializer=Constant(0.0),
is_bias=True)
else:
self.scale = None
            self.bias = None

    def forward(self, input):
return instance_norm(
            input, weight=self.scale, bias=self.bias, eps=self._epsilon)

    def extra_repr(self):
return 'num_features={}, epsilon={}'.format(self.scale.shape[0],
self._epsilon)
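A quick usage sketch of the new custom layer (shapes chosen arbitrarily):

    import paddle

    norm = InstanceNorm(num_features=16)
    x = paddle.randn([4, 16, 32, 32])    # NCHW input
    y = norm(x)                          # per-instance, per-channel normalization
    print(norm.extra_repr())             # num_features=16, epsilon=1e-05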
@@ -37,7 +37,7 @@ class PyTorchOpMapper(OpMapper):
         self.scope_name_list = list()
         self.scope_name2id = dict()
         self.inputs_info = dict()
         self.output2id = dict()  # maps an output name to its layer_id; used by LSTM to drop preceding nodes
         # Convert
         if not self.op_checker(decoder.graph):
             raise Exception("Model is not supported yet.")
@@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper):
             op_list.append(node.kind())
             for block in node.blocks():
                 _update_op_list(block)
+
         op_list = list()
         _update_op_list(script_graph)
         op_list = list(set(op_list))
@@ -62,11 +63,11 @@ class PyTorchOpMapper(OpMapper):
             return True
         else:
             if len(unsupported_ops) > 0:
-                print("\n========= {} OPs are not supported yet ===========".format(
-                    len(unsupported_ops)))
+                print("\n========= {} OPs are not supported yet ===========".
+                      format(len(unsupported_ops)))
                 for op in unsupported_ops:
                     print("========== {} ============".format(op))
                 return False

     def traverse(self, script_graph, parent_layer=None):
         # Used to collect the graph inputs
@@ -85,20 +86,24 @@ class PyTorchOpMapper(OpMapper):
             current_node_outputs.extend(outputs)

         # Initialization
-        graph = PaddleGraph(source_type="pytorch", parent_layer=parent_layer, graph_type="dygraph")
+        graph = PaddleGraph(
+            source_type="pytorch",
+            parent_layer=parent_layer,
+            graph_type="dygraph")
         if "TopLevelTracedModule" in str(type(self.script)):
             graph.set_script(self.script)
         current_node_outputs = []
         graph_inputs = []
         # Convert the input nodes
         if isinstance(script_graph, torch._C.Graph):
             input_ct = 0
             for i, ivalue in enumerate(script_graph.inputs()):
                 node = ivalue.node()
                 if str(ivalue.type()) not in ["Tensor", "Dict[str, Tensor]"]:
                     graph.set_name(str(ivalue.type()).split(".")[-1])
                     continue
-                inputs, outputs = self.data(graph, node, ivalue.unique(), input_ct)
+                inputs, outputs = self.data(graph, node,
+                                            ivalue.unique(), input_ct)
                 input_ct += 1
         # Convert the intermediate nodes
         for node in script_graph.nodes():
@@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper):
                     outputs=[output_name],
                     scope_name=scope_name,
                     dtype=string(str(param.dtype)),
-                    shape = param.shape,
-                    default_initializer="paddle.nn.initializer.Constant(value=0.0)")
+                    shape=param.shape,
+                    default_initializer="paddle.nn.initializer.Constant(value=0.0)"
+                )
                 self.output2id[output_name] = layer_id
             else:
                 if isinstance(param, dict) and "Tensor" in param and \
@@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper):
                         outputs=[output_name],
                         scope_name=scope_name,
                         dtype=string(str(param.dtype)),
-                        shape = param.shape,
-                        default_initializer="paddle.nn.initializer.Constant(value=0.0)")
+                        shape=param.shape,
+                        default_initializer="paddle.nn.initializer.Constant(value=0.0)"
+                    )
                     node_outputs.append(output_name)
                     self.output2id[output_name] = layer_id
                 return
@@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper):
                 value=string(param)
                 if isinstance(param, str) else param)
             node_outputs.append(output_name)
-        elif node.kind() == "prim::Constant" and output_name in self.pytorch_params:
+        elif node.kind(
+        ) == "prim::Constant" and output_name in self.pytorch_params:
             param = self.pytorch_params[output_name]
             self.paddle_params[output_name] = param
             layer_id = graph.add_layer(
@@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper):
                 outputs=[output_name],
                 scope_name=scope_name,
                 dtype=string(str(param.dtype)),
-                shape = param.shape,
+                shape=param.shape,
                 default_initializer="paddle.nn.initializer.Constant(value=0.0)")
             self.output2id[output_name] = layer_id
-

     def _get_inputs_name(self, node):
         inputs_name = []
         inputs_node = []
@@ -256,7 +263,6 @@ class PyTorchOpMapper(OpMapper):
             inputs_node.append(script_input_node)
             inputs_name.append(input_name)
         return inputs_name, inputs_node
-

     def data(self, graph, node, uid, input_ct):
         scope_name = self.normalize_scope_name(node)
@@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper):
             data=output_name)
         if self.input_examples is not None:
             input_np = self.input_examples[input_ct].detach().numpy()
-            self.inputs_info[output_name] = [list(input_np.shape), str(input_np.dtype)]
+            self.inputs_info[
+                output_name] = [list(input_np.shape), str(input_np.dtype)]
         return [], [output_name]

     def equal(self, graph, node, uid=None, parent_layer=None, index=None):
@@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper):
             control_output_id = index - 1
         output_node_name = parent_layer.outputs[control_output_id]
         current_outputs = [output_node_name]
-        self._check_input(graph, node, input_node_name, current_outputs, scope_name)
+        self._check_input(graph, node, input_node_name, current_outputs,
+                          scope_name)
         graph.add_layer(
             "prim.equal",
             inputs={'input': input_node_name},
self.scope_name2id[i][ns] = 0 self.scope_name2id[i][ns] = 0
real_scope_name = "/".join(name_segments[1:]) real_scope_name = "/".join(name_segments[1:])
real_father_scope_name = "/".join(name_segments[1:-1]) real_father_scope_name = "/".join(name_segments[1:-1])
for i, ns in enumerate(name_segments): for i, ns in enumerate(name_segments):
if i == 0: if i == 0:
continue continue
if self.scope_name2id[i][ns] != 0: if self.scope_name2id[i][ns] != 0:
name_segments[i] = name_segments[i] + \ name_segments[i] = name_segments[i] + \
"__{}".format(self.scope_name2id[i][ns]) "__{}".format(self.scope_name2id[i][ns])
prefix_scope_name = "/".join(name_segments[1 :i + 1]) prefix_scope_name = "/".join(name_segments[1:i + 1])
is_found = False is_found = False
for j in range(len(self.scope_name_list)): for j in range(len(self.scope_name_list)):
last_scope_name = self.scope_name_list[-1-j] last_scope_name = self.scope_name_list[-1 - j]
if last_scope_name.startswith(prefix_scope_name + "/") \ if last_scope_name.startswith(prefix_scope_name + "/") \
or last_scope_name == prefix_scope_name: or last_scope_name == prefix_scope_name:
if j != 0: # and i != len(name_segments) - 1: if j != 0: # and i != len(name_segments) - 1:
is_found = True is_found = True
origin_name_segment_i = name_segments[i].split("__")[0] origin_name_segment_i = name_segments[i].split("__")[0]
self.scope_name2id[i][origin_name_segment_i] += 1 self.scope_name2id[i][origin_name_segment_i] += 1
...@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper): ...@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper):
real_scope_name = "/".join(name_segments[1:]) real_scope_name = "/".join(name_segments[1:])
self.scope_name_list.append(real_scope_name) self.scope_name_list.append(real_scope_name)
return real_scope_name return real_scope_name
\ No newline at end of file
@@ -248,8 +248,10 @@ class TFOpMapper(OpMapper):
     def Transpose(self, node):
         input = self.graph.get_input_node(node, 0)
         perm = self.graph.get_input_node(node, 1)
-        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
-        perm = perm.value.tolist()
+        if perm.layer_type == "Const":
+            perm = perm.value.tolist()
+        else:
+            perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist()

         self.paddle_graph.add_layer(
             "paddle.transpose",
@@ -641,12 +643,18 @@ class TFOpMapper(OpMapper):
         paddings = self.graph.get_input_node(node, 1)
         assert paddings.layer_type == "Const", "Padding should be Const"
         paddings = paddings.value.flatten().tolist()
+        constant_values = 0
+        if len(node.layer.input) > 2:
+            constant_values = self.graph.get_input_node(node, 2)
+            assert constant_values.layer_type == "Const", "Constant_values should be Const"
+            constant_values = constant_values.value

         self.paddle_graph.add_layer(
             kernel="paddle.nn.functional.pad",
             inputs={"x": input.name},
             outputs=[node.name],
-            pad=paddings)
+            pad=paddings,
+            value=constant_values)

     def MirrorPad(self, node):
         self.Pad(node)
......
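The extra input corresponds to TF's PadV2, whose third input is the fill value; the resulting Paddle call behaves roughly like this (illustrative tensor and values):

    import paddle

    x = paddle.zeros([1, 3, 2, 2])
    # TF paddings Const of shape [4, 2], flattened row by row to
    # [before_0, after_0, before_1, after_1, ...], as in the hunk above.
    paddings = [0, 0, 0, 0, 1, 1, 1, 1]
    # In constant mode a pad list of length 2 * ndim pads every dimension.
    y = paddle.nn.functional.pad(x, pad=paddings, value=9.0)
    print(y.shape)   # [1, 3, 4, 4]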
@@ -238,8 +238,10 @@ class TFOpMapper(OpMapper):
     def Transpose(self, node):
         input = self.graph.get_node(node.layer.input[0])
         perm = self.graph.get_node(node.layer.input[1])
-        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
-        perm = perm.value.tolist()
+        if perm.layer_type == "Const":
+            perm = perm.value.tolist()
+        else:
+            perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist()

         self.paddle_graph.add_layer(
             kernel="paddle.transpose",
paddings = self.graph.get_input_node(node, 1) paddings = self.graph.get_input_node(node, 1)
assert paddings.layer_type == "Const", "Padding should be Const" assert paddings.layer_type == "Const", "Padding should be Const"
paddings = paddings.value.flatten().tolist() paddings = paddings.value.flatten().tolist()
constant_values = 0
if len(node.layer.input) > 2:
constant_values = self.graph.get_input_node(node, 2)
assert constant_values.layer_type == "Const", "Padding should be Const"
constant_values = constant_values.value
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad", kernel="paddle.nn.functional.pad",
inputs={"x": input.name}, inputs={"x": input.name},
outputs=[node.name], outputs=[node.name],
pad=paddings) pad=paddings,
value=constant_values)
def MirrorPad(self, node): def MirrorPad(self, node):
self.Pad(node) self.Pad(node)
......
@@ -27,6 +27,8 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
                   "paddle.nn.Linear": "linear",
                   "paddle.nn.Conv2DTranspose": "conv",
                   "paddle.nn.LSTM": "lstm",
+                  "paddle.nn.GRU": "gru",
+                  "custom_layer:InstanceNorm": "instance_norm",
                   "paddle.nn.PReLU": "prelu",
                   "paddle.nn.ReLU": "relu",
                   "paddle.nn.ReLU6": "relu",
@@ -35,14 +37,14 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
                   "paddle.nn.Tanh": "tanh",
                   "paddle.nn.AvgPool2D": "avgpool",
                   "paddle.nn.MaxPool2D": "maxpool",
-                  "paddle.nn.Pad1D": "pad",
-                  "paddle.nn.Pad2D": "pad",
-                  "paddle.nn.Pad3D": "pad",
+                  "paddle.nn.Pad1D": "pad1d",
+                  "paddle.nn.Pad2D": "pad2d",
+                  "paddle.nn.Pad3D": "pad3d",
                   "paddle.nn.Dropout": "dropout",
                   "paddle.nn.GELU": "gelu",
                   "paddle.nn.Hardtanh": "tanh",
                   "paddle.nn.LeakyReLU": "leakly_relu"}
-NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:8]
+NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:10]

def rename_layers(layers, param_tree=None, is_rename_module=False):
    """ Rename the inputs/outputs of sub-module layers.
@@ -143,7 +145,10 @@ def _update_attrs(layer, different_attrs):
         if key_name in different_attrs:
             common_attrs.pop(k)
             special_attrs[k] = v
-    remove_default_attrs(layer.kernel, common_attrs)
+    remove_kernel = layer.kernel
+    if remove_kernel == "custom_layer:InstanceNorm":
+        remove_kernel = "paddle.nn.InstanceNorm2D"
+    remove_default_attrs(remove_kernel, common_attrs)
     common_attrs.update(special_attrs)
     layer.attrs = common_attrs
......
@@ -212,6 +212,8 @@ class ModuleGraph(object):
         layer_id_list2 = list(sub_layers2.keys())
         for i, layer_id1 in enumerate(layer_id_list1):
             layer_id2 = layer_id_list2[i]
+            if layer_id2 not in self.pd_graph.edges_in:
+                return False
             if len(self.pd_graph.edges_in[layer_id1]) != len(self.pd_graph.edges_in[layer_id2]):
                 return False
             for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[layer_id1]):
......
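A small sketch of why the guard is needed when aligning two sub-layer sequences (edges_in maps a layer id to its producer ids, and a layer without inputs has no entry at all):

    # Stand-in for self.pd_graph.edges_in; layer "4" has no producers.
    edges_in = {"2": ["1"], "3": ["2"]}

    def same_fan_in(layer_id1, layer_id2):
        if layer_id2 not in edges_in:    # would otherwise raise KeyError
            return False
        return len(edges_in[layer_id1]) == len(edges_in[layer_id2])

    print(same_fan_in("2", "3"))   # True
    print(same_fan_in("2", "4"))   # False instead of KeyError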