未验证 提交 0c93656b 编写于 作者: W WJJ1995 提交者: GitHub

Support PyTorch InstanceNorm2d op (#638)

* Add PyTorch LeakyReLU op

* Fix PyTorch InstanceNorm op

* Update Paddle code format
上级 d5c2ceb3
...@@ -2637,7 +2637,7 @@ def aten_instance_norm(mapper, graph, node): ...@@ -2637,7 +2637,7 @@ def aten_instance_norm(mapper, graph, node):
# 处理输入1,即%88 # 处理输入1,即%88
if inputs_name[1] in mapper.pytorch_params: if inputs_name[1] in mapper.pytorch_params:
weights = mapper.pytorch_params[inputs_name[1]] weights = mapper.pytorch_params[inputs_name[1]]
mapper.paddle_params[op_name + ".weight"] = weights mapper.paddle_params[op_name + ".scale"] = weights
layer_attrs['num_features'] = weights.shape[0] layer_attrs['num_features'] = weights.shape[0]
# 处理输入2,即%85 # 处理输入2,即%85
if inputs_name[2] in mapper.pytorch_params: if inputs_name[2] in mapper.pytorch_params:
...@@ -2888,6 +2888,42 @@ def aten_leaky_relu_(mapper, graph, node): ...@@ -2888,6 +2888,42 @@ def aten_leaky_relu_(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_leaky_relu(mapper, graph, node):
    """Build the PaddleLayer for a leaky ReLU activation.

    TorchScript example:
        %input.117 : Tensor = aten::leaky_relu(%input.114, %1570)
    Argument meaning:
        %input.117 (Tensor): output, the result of the leaky ReLU.
        %input.114 (Tensor): the Tensor the leaky ReLU is applied to.
        %1570 (float): slope used for input elements smaller than 0.

    Args:
        mapper: object providing the name/attribute lookups used below
            (``normalize_scope_name``, ``_get_outputs_name``,
            ``_get_inputs_name``, ``_check_input``, ``attrs``, ``nn_name2id``).
        graph: graph that receives the new ``paddle.nn.LeakyReLU`` layer.
        node: the TorchScript node being converted.

    Returns:
        A ``(current_inputs, current_outputs)`` tuple: the list holding the
        name of input 0, and the list holding the node's first output name.
    """
    scope_name = mapper.normalize_scope_name(node)
    # The "leakly_relu" spelling is reproduced verbatim: it is a runtime
    # string (presumably part of the generated layer name -- confirm before
    # ever renaming it).
    layer_name = name_generator("leakly_relu", mapper.nn_name2id)
    result_name = mapper._get_outputs_name(node)[0]
    inputs_name, inputs_node = mapper._get_inputs_name(node)

    # Output names produced by the current node.
    current_outputs = [result_name]

    # Input 0 (%input.114): the tensor fed to the activation.
    tensor_name = inputs_name[0]
    mapper._check_input(graph, inputs_node[0], tensor_name, current_outputs,
                        scope_name)
    layer_inputs = {"x": tensor_name}
    # Input names consumed by the current node.
    current_inputs = [tensor_name]

    # Input 1 (%1570): the slope is read from mapper.attrs, so a missing
    # entry raises KeyError here.
    negative_slope = mapper.attrs[inputs_name[1]]

    graph.add_layer(
        "paddle.nn.LeakyReLU",
        inputs=layer_inputs,
        outputs=[layer_name, result_name],
        scope_name=scope_name,
        negative_slope=negative_slope)
    return current_inputs, current_outputs
def aten_len(mapper, graph, node): def aten_len(mapper, graph, node):
""" 构造获取list长度的PaddleLayer。 """ 构造获取list长度的PaddleLayer。
TorchScript示例: TorchScript示例:
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import copy
import os.path as osp import os.path as osp
from treelib import Tree from treelib import Tree
...@@ -29,6 +28,7 @@ SEPARATOR_IN_SCOPE = "/" ...@@ -29,6 +28,7 @@ SEPARATOR_IN_SCOPE = "/"
class HierarchicalTree(Tree): class HierarchicalTree(Tree):
""" 定义层次树。 """ 定义层次树。
""" """
def __init__(self, pd_graph): def __init__(self, pd_graph):
super(HierarchicalTree, self).__init__() super(HierarchicalTree, self).__init__()
self.pd_graph = pd_graph self.pd_graph = pd_graph
...@@ -61,7 +61,8 @@ class HierarchicalTree(Tree): ...@@ -61,7 +61,8 @@ class HierarchicalTree(Tree):
if layer.kernel == "prim.tuple": if layer.kernel == "prim.tuple":
for i, input_layer_id in enumerate(layer_id_list): for i, input_layer_id in enumerate(layer_id_list):
input_layer_id_str = str(input_layer_id) input_layer_id_str = str(input_layer_id)
scope_name = self.pd_graph.layers[input_layer_id_str].scope_name scope_name = self.pd_graph.layers[
input_layer_id_str].scope_name
if i == 0: if i == 0:
min_scope_name = scope_name min_scope_name = scope_name
else: else:
...@@ -73,7 +74,8 @@ class HierarchicalTree(Tree): ...@@ -73,7 +74,8 @@ class HierarchicalTree(Tree):
if len1 > len2: if len1 > len2:
min_scope_name = scope_name min_scope_name = scope_name
if min_scope_name == "": if min_scope_name == "":
self.create_node(tag=layer.id, self.create_node(
tag=layer.id,
identifier="no_scope_" + layer.id, identifier="no_scope_" + layer.id,
parent=self.pd_graph.name, parent=self.pd_graph.name,
data=layer) data=layer)
...@@ -83,20 +85,23 @@ class HierarchicalTree(Tree): ...@@ -83,20 +85,23 @@ class HierarchicalTree(Tree):
else: else:
for input_layer_id in layer_id_list: for input_layer_id in layer_id_list:
input_layer_id_str = str(input_layer_id) input_layer_id_str = str(input_layer_id)
if self.pd_graph.layers[input_layer_id_str].scope_name != "": if self.pd_graph.layers[
scope_name = self.pd_graph.layers[input_layer_id_str].scope_name input_layer_id_str].scope_name != "":
scope_name = self.pd_graph.layers[
input_layer_id_str].scope_name
break break
layer.scope_name = scope_name layer.scope_name = scope_name
else: else:
self.create_node(tag=layer.id, self.create_node(
tag=layer.id,
identifier="no_scope_" + layer.id, identifier="no_scope_" + layer.id,
parent=self.pd_graph.name, parent=self.pd_graph.name,
data=layer) data=layer)
return return
scopes = scope_name.split(SEPARATOR_IN_SCOPE) scopes = scope_name.split(SEPARATOR_IN_SCOPE)
for idx, scope in enumerate(scopes): for idx, scope in enumerate(scopes):
parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])#.lower() parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) #.lower()
identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])#.lower() identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1]) #.lower()
if self.contains(identifier): if self.contains(identifier):
if idx != len(scopes) - 1: if idx != len(scopes) - 1:
parent_node = self.parent(identifier) parent_node = self.parent(identifier)
...@@ -109,8 +114,10 @@ class HierarchicalTree(Tree): ...@@ -109,8 +114,10 @@ class HierarchicalTree(Tree):
self.identifier_idx[identifier] = 0 self.identifier_idx[identifier] = 0
else: else:
self.identifier_idx[identifier] += 1 self.identifier_idx[identifier] += 1
identifier_name = identifier + SEPARATOR_IN_SCOPE + str(self.identifier_idx[identifier]) identifier_name = identifier + SEPARATOR_IN_SCOPE + str(
self.create_node(tag=scopes[idx], self.identifier_idx[identifier])
self.create_node(
tag=scopes[idx],
identifier=identifier_name, identifier=identifier_name,
parent=identifier, parent=identifier,
data=data) data=data)
...@@ -125,20 +132,25 @@ class HierarchicalTree(Tree): ...@@ -125,20 +132,25 @@ class HierarchicalTree(Tree):
self.identifier_idx[identifier] = 0 self.identifier_idx[identifier] = 0
else: else:
self.identifier_idx[identifier] += 1 self.identifier_idx[identifier] += 1
self.create_node(tag=scopes[idx], self.create_node(
identifier=identifier + SEPARATOR_IN_SCOPE + str(self.identifier_idx[identifier]), tag=scopes[idx],
identifier=identifier + SEPARATOR_IN_SCOPE +
str(self.identifier_idx[identifier]),
parent=identifier, parent=identifier,
data=data) data=data)
self.identifier_idx[identifier] += 1 self.identifier_idx[identifier] += 1
data = layer data = layer
self.create_node(tag=scopes[idx], self.create_node(
identifier=identifier + SEPARATOR_IN_SCOPE + str(self.identifier_idx[identifier]), tag=scopes[idx],
identifier=identifier + SEPARATOR_IN_SCOPE +
str(self.identifier_idx[identifier]),
parent=identifier, parent=identifier,
data=data) data=data)
continue continue
if idx == 0 and not self.contains(identifier): if idx == 0 and not self.contains(identifier):
data = layer if idx == len(scopes) - 1 else None data = layer if idx == len(scopes) - 1 else None
self.create_node(tag=scopes[idx], self.create_node(
tag=scopes[idx],
identifier=identifier, identifier=identifier,
parent=self.pd_graph.name, parent=self.pd_graph.name,
data=data) data=data)
...@@ -153,7 +165,8 @@ class HierarchicalTree(Tree): ...@@ -153,7 +165,8 @@ class HierarchicalTree(Tree):
identifiers = list() identifiers = list()
for child in childs: for child in childs:
child_identifier = child.identifier child_identifier = child.identifier
if child_identifier.startswith(prefix) and child_identifier != prefix: if child_identifier.startswith(
prefix) and child_identifier != prefix:
identifiers.append(child_identifier) identifiers.append(child_identifier)
if len(identifiers) == 0: if len(identifiers) == 0:
identifier = prefix + "_0" identifier = prefix + "_0"
...@@ -162,14 +175,15 @@ class HierarchicalTree(Tree): ...@@ -162,14 +175,15 @@ class HierarchicalTree(Tree):
for id_obj in identifiers: for id_obj in identifiers:
identifier_ids.append(int(id_obj.split("_")[-1])) identifier_ids.append(int(id_obj.split("_")[-1]))
identifier_ids.sort() identifier_ids.sort()
identifier = prefix + "_{}".format(identifier_ids[-1] + 1) identifier = prefix + "_{}".format(identifier_ids[-1] +
1)
data = layer if idx == len(scopes) - 1 else None data = layer if idx == len(scopes) - 1 else None
self.create_node(tag=scopes[idx], self.create_node(
tag=scopes[idx],
identifier=identifier, identifier=identifier,
parent=parent, parent=parent,
data=data) data=data)
def update_hierarchical_order(self): def update_hierarchical_order(self):
""" 更新层次排序,使用一个字典存储该信息, """ 更新层次排序,使用一个字典存储该信息,
关键字为当前层次,值为节点名字。 关键字为当前层次,值为节点名字。
...@@ -202,7 +216,8 @@ class HierarchicalTree(Tree): ...@@ -202,7 +216,8 @@ class HierarchicalTree(Tree):
break break
return diff_attrs_column return diff_attrs_column
def merge_node(self, sub_layers_list, attrs_table, node_name2sub_layers, module_name): def merge_node(self, sub_layers_list, attrs_table, node_name2sub_layers,
module_name):
""" 将一个scope的节点合成一个Module(Class),并将对应的Class代码 """ 将一个scope的节点合成一个Module(Class),并将对应的Class代码
放到code字符串中。 放到code字符串中。
""" """
...@@ -224,7 +239,10 @@ class HierarchicalTree(Tree): ...@@ -224,7 +239,10 @@ class HierarchicalTree(Tree):
module_name = module_name[0].upper() + module_name[1:] module_name = module_name[0].upper() + module_name[1:]
if module_name in self.module_name2count: if module_name in self.module_name2count:
module_name = module_name + "_0" module_name = module_name + "_0"
code_str = gen_layer_code(self.pd_graph, sub_layers, module_name, code_str = gen_layer_code(
self.pd_graph,
sub_layers,
module_name,
different_attrs=diff_attrs_column) different_attrs=diff_attrs_column)
self.codes.append(code_str) self.codes.append(code_str)
...@@ -241,7 +259,9 @@ class HierarchicalTree(Tree): ...@@ -241,7 +259,9 @@ class HierarchicalTree(Tree):
mn = module_name.lower() + "__" mn = module_name.lower() + "__"
else: else:
mn = module_name.lower() mn = module_name.lower()
outputs = ["{}/{}".format(mn, self.module_name2count[module_name])] + outputs outputs = [
"{}/{}".format(mn, self.module_name2count[module_name])
] + outputs
node_name = get_node_name(sub_layers) node_name = get_node_name(sub_layers)
diff_attrs = dict() diff_attrs = dict()
for column in diff_attrs_column: for column in diff_attrs_column:
...@@ -250,7 +270,8 @@ class HierarchicalTree(Tree): ...@@ -250,7 +270,8 @@ class HierarchicalTree(Tree):
node_name_seg = node_name.split(SEPARATOR_IN_SCOPE) node_name_seg = node_name.split(SEPARATOR_IN_SCOPE)
node_name_seg[-1] = module_name.lower() node_name_seg[-1] = module_name.lower()
new_node_name = SEPARATOR_IN_SCOPE.join(node_name_seg) new_node_name = SEPARATOR_IN_SCOPE.join(node_name_seg)
new_layer = PaddleLayer(id=list(sub_layers.keys())[-1], new_layer = PaddleLayer(
id=list(sub_layers.keys())[-1],
kernel="module", kernel="module",
inputs=inputs_dict, inputs=inputs_dict,
outputs=outputs, outputs=outputs,
...@@ -273,16 +294,15 @@ class HierarchicalTree(Tree): ...@@ -273,16 +294,15 @@ class HierarchicalTree(Tree):
self.pd_graph.build() self.pd_graph.build()
self[node_name].data = new_layer self[node_name].data = new_layer
def find_subgraph_diff(self, module_name2sub_layers,
def find_subgraph_diff(self, module_name2sub_layers, module_name2sub_identifiers, node_name2sub_layers, name): module_name2sub_identifiers, node_name2sub_layers,
name):
""" 查找子图的diff,主要是输入参数的diff。 """ 查找子图的diff,主要是输入参数的diff。
""" """
sub_layers = module_name2sub_layers[name] sub_layers = module_name2sub_layers[name]
sub_identifiers = module_name2sub_identifiers[name] sub_identifiers = module_name2sub_identifiers[name]
new_sub_layers, new_sub_sequentials, sequentials2attrs_table = distinguish_sequential(self.pd_graph, new_sub_layers, new_sub_sequentials, sequentials2attrs_table = distinguish_sequential(
name, self.pd_graph, name, sub_layers, sub_identifiers,
sub_layers,
sub_identifiers,
node_name2sub_layers) node_name2sub_layers)
module_name2sub_layers.pop(name) module_name2sub_layers.pop(name)
module_name2sub_identifiers.pop(name) module_name2sub_identifiers.pop(name)
...@@ -291,7 +311,6 @@ class HierarchicalTree(Tree): ...@@ -291,7 +311,6 @@ class HierarchicalTree(Tree):
module_name2sub_identifiers[k] = new_sub_sequentials[k] module_name2sub_identifiers[k] = new_sub_sequentials[k]
return sequentials2attrs_table return sequentials2attrs_table
def convert_subgraph_to_layer(self): def convert_subgraph_to_layer(self):
""" """
1. 根据_hierarchical_order,从最深的层次开始将 1. 根据_hierarchical_order,从最深的层次开始将
...@@ -314,8 +333,10 @@ class HierarchicalTree(Tree): ...@@ -314,8 +333,10 @@ class HierarchicalTree(Tree):
sub_layers = dict() sub_layers = dict()
sub_identifiers = dict() sub_identifiers = dict()
for successor_name in node_inst.successors(self.identifier): for successor_name in node_inst.successors(self.identifier):
sub_layers[self[successor_name].data.id] = self[successor_name].data sub_layers[self[successor_name].data.id] = self[
sub_identifiers[self[successor_name].data.id] = self[successor_name].data.scope_name.split("/")[-1] successor_name].data
sub_identifiers[self[successor_name].data.id] = self[
successor_name].data.scope_name.split("/")[-1]
node_name2sub_layers[node_name] = sub_layers node_name2sub_layers[node_name] = sub_layers
node_name_segs = node_name.split("/") node_name_segs = node_name.split("/")
...@@ -345,23 +366,29 @@ class HierarchicalTree(Tree): ...@@ -345,23 +366,29 @@ class HierarchicalTree(Tree):
len(module_name2sub_layers[module_name][0][list(module_name2sub_layers[module_name][0].keys())[-1]].outputs): len(module_name2sub_layers[module_name][0][list(module_name2sub_layers[module_name][0].keys())[-1]].outputs):
break break
if module_name not in module_name2sub_layers: if module_name not in module_name2sub_layers:
module_name2sub_layers[module_name] = [sub_layers] module_name2sub_layers[
module_name2sub_identifiers[module_name] = [sub_identifiers] module_name] = [sub_layers]
module_name2sub_identifiers[
module_name] = [sub_identifiers]
else: else:
module_name2sub_layers[module_name].append(sub_layers) module_name2sub_layers[module_name].append(
module_name2sub_identifiers[module_name].append(sub_identifiers) sub_layers)
module_name2sub_identifiers[module_name].append(
sub_identifiers)
else: else:
module_name2sub_layers[module_name].append(sub_layers) module_name2sub_layers[module_name].append(
module_name2sub_identifiers[module_name].append(sub_identifiers) sub_layers)
module_name2sub_identifiers[module_name].append(
sub_identifiers)
else: else:
module_name2sub_layers[module_name] = [sub_layers] module_name2sub_layers[module_name] = [sub_layers]
module_name2sub_identifiers[module_name] = [sub_identifiers] module_name2sub_identifiers[
module_name] = [sub_identifiers]
module_names = list(module_name2sub_layers.keys()) module_names = list(module_name2sub_layers.keys())
for module_name in module_names: for module_name in module_names:
sequentials2attrs_table = self.find_subgraph_diff(module_name2sub_layers, sequentials2attrs_table = self.find_subgraph_diff(
module_name2sub_identifiers, module_name2sub_layers, module_name2sub_identifiers,
node_name2sub_layers, node_name2sub_layers, module_name)
module_name)
for name in sequentials2attrs_table.keys(): for name in sequentials2attrs_table.keys():
if name.startswith("Sequential"): if name.startswith("Sequential"):
# 若Module的名字为Sequential,则以scope_name的名字来命名,在merge_node中实现 # 若Module的名字为Sequential,则以scope_name的名字来命名,在merge_node中实现
...@@ -373,19 +400,19 @@ class HierarchicalTree(Tree): ...@@ -373,19 +400,19 @@ class HierarchicalTree(Tree):
current_module_name_list.append(module_name) current_module_name_list.append(module_name)
self.merge_node(module_name2sub_layers[name], self.merge_node(module_name2sub_layers[name],
sequentials2attrs_table[name], sequentials2attrs_table[name],
node_name2sub_layers, node_name2sub_layers, module_name)
module_name)
def update_parameters(self): def update_parameters(self):
""" 更新参数。 """ 更新参数。
""" """
self.param_tree.traverse() self.param_tree.traverse()
full_old_name_list = copy.deepcopy(list(self.pd_graph.parameters.keys())) full_old_name_list = copy.deepcopy(
list(self.pd_graph.parameters.keys()))
for old_name, new_name in self.param_tree.old2new.items(): for old_name, new_name in self.param_tree.old2new.items():
for full_old_name in full_old_name_list: for full_old_name in full_old_name_list:
if full_old_name.startswith("{}.".format(old_name)): if full_old_name.startswith("{}.".format(old_name)):
full_new_name = full_old_name.replace("{}.".format(old_name), "{}.".format(new_name)) full_new_name = full_old_name.replace(
"{}.".format(old_name), "{}.".format(new_name))
params = self.pd_graph.parameters.pop(full_old_name) params = self.pd_graph.parameters.pop(full_old_name)
self.pd_graph.parameters[full_new_name] = params self.pd_graph.parameters[full_new_name] = params
if full_old_name == old_name: if full_old_name == old_name:
...@@ -398,27 +425,30 @@ class HierarchicalTree(Tree): ...@@ -398,27 +425,30 @@ class HierarchicalTree(Tree):
input_data_name = ', '.join(self.pd_graph.inputs) input_data_name = ', '.join(self.pd_graph.inputs)
run_func_list = list() run_func_list = list()
run_func_list.append("def main({}):".format(input_data_name)) run_func_list.append("def main({}):".format(input_data_name))
run_func_list.append(" # There are {} inputs.".format(len(self.pd_graph.inputs_info))) run_func_list.append(" # There are {} inputs.".format(
len(self.pd_graph.inputs_info)))
for k, v in self.pd_graph.inputs_info.items(): for k, v in self.pd_graph.inputs_info.items():
run_func_list.append(" # {}: shape-{}, type-{}.".format(k, v[0], v[1])) run_func_list.append(" # {}: shape-{}, type-{}.".format(k, v[
run_func_list.extend( 0], v[1]))
[" paddle.disable_static()", run_func_list.extend([
" params = paddle.load('{}')".format(osp.join(osp.abspath(save_dir), "model.pdparams")), " paddle.disable_static()",
" params = paddle.load('{}')".format(
osp.join(osp.abspath(save_dir), "model.pdparams")),
" model = {}()".format(self.pd_graph.name), " model = {}()".format(self.pd_graph.name),
" model.set_dict(params)", " model.set_dict(params)", " model.eval()",
" model.eval()", " out = model({})".format(input_data_name), " return out"
" out = model({})".format(input_data_name), ])
" return out"])
return "\n".join(run_func_list) return "\n".join(run_func_list)
self.update_hierarchical_order() self.update_hierarchical_order()
self.convert_subgraph_to_layer() self.convert_subgraph_to_layer()
self.update_parameters() self.update_parameters()
import_list = ["import paddle", import_list = ["import paddle",
"import math", "import math",
"from x2paddle.op_mapper.pytorch2paddle " + \ "from x2paddle.op_mapper.pytorch2paddle " + \
"import pytorch_custom_layer as x2paddle_nn" "import pytorch_custom_layer as x2paddle_nn",
"\n",] "",]
import_str = "\n".join(import_list) import_str = "\n".join(import_list) + "\n"
if not osp.exists(save_dir): if not osp.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
f = open(osp.join(save_dir, 'x2paddle_code.py'), 'w') f = open(osp.join(save_dir, 'x2paddle_code.py'), 'w')
......
...@@ -263,7 +263,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -263,7 +263,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
layer.kernel.startswith("custom_layer"): layer.kernel.startswith("custom_layer"):
line = "self.{}".format(layer.outputs[0]) line = "self.{}".format(layer.outputs[0])
if layer.kernel.startswith("custom_layer"): if layer.kernel.startswith("custom_layer"):
line += "= x2paddle_nn.{}(".format(layer.kernel.split(":")[-1]) line += " = x2paddle_nn.{}(".format(layer.kernel.split(":")[-1])
else: else:
line += " = {}(".format(layer.kernel) line += " = {}(".format(layer.kernel)
for k, v in layer.attrs.items(): for k, v in layer.attrs.items():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册