未验证 提交 21ff465e 编写于 作者: W WJJ1995 提交者: GitHub

fixed pytorch codegen bug (#650)

上级 379ce426
...@@ -5388,7 +5388,7 @@ def aten_upsample_bilinear2d(mapper, graph, node): ...@@ -5388,7 +5388,7 @@ def aten_upsample_bilinear2d(mapper, graph, node):
%4963 (list): 上采样后的大小。 %4963 (list): 上采样后的大小。
%5421 (bool): 若为True,则将输入和输出张量的4个角落像素的中心对齐,并保留角点像素的值。 %5421 (bool): 若为True,则将输入和输出张量的4个角落像素的中心对齐,并保留角点像素的值。
%4995 (float): 高度的乘数因子。 %4995 (float): 高度的乘数因子。
%4995 (float): 宽度的乘数因子。 %4996 (float): 宽度的乘数因子。
""" """
scope_name = mapper.normalize_scope_name(node) scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
...@@ -5465,7 +5465,7 @@ def aten_upsample_nearest2d(mapper, graph, node): ...@@ -5465,7 +5465,7 @@ def aten_upsample_nearest2d(mapper, graph, node):
%4997 (Tensor): 输出,上采样后的Tensor。 %4997 (Tensor): 输出,上采样后的Tensor。
%x.13 (Tensor): 需要上采样的Tensor。 %x.13 (Tensor): 需要上采样的Tensor。
%4963 (list): 上采样后的大小。 %4963 (list): 上采样后的大小。
%4995 (float): 高度的乘数因子。 %5421 (float): 高度的乘数因子。
%4995 (float): 宽度的乘数因子。 %4995 (float): 宽度的乘数因子。
""" """
scope_name = mapper.normalize_scope_name(node) scope_name = mapper.normalize_scope_name(node)
......
...@@ -19,33 +19,35 @@ import x2paddle ...@@ -19,33 +19,35 @@ import x2paddle
from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode
from x2paddle.core.util import * from x2paddle.core.util import *
NN_KERNEL_NAME = {
NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn", "paddle.nn.BatchNorm": "bn",
"paddle.nn.LayerNorm": "layernorm", "paddle.nn.LayerNorm": "layernorm",
"paddle.nn.Conv2D": "conv", "paddle.nn.Conv2D": "conv",
"paddle.nn.Embedding": "embedding", "paddle.nn.Embedding": "embedding",
"paddle.nn.Linear": "linear", "paddle.nn.Linear": "linear",
"paddle.nn.Conv2DTranspose": "conv", "paddle.nn.Conv2DTranspose": "conv",
"paddle.nn.LSTM": "lstm", "paddle.nn.LSTM": "lstm",
"paddle.nn.GRU": "gru", "paddle.nn.GRU": "gru",
"custom_layer:InstanceNorm": "instance_norm", "custom_layer:InstanceNorm": "instance_norm",
"paddle.nn.PReLU": "prelu", "paddle.nn.PReLU": "prelu",
"paddle.nn.ReLU": "relu", "paddle.nn.ReLU": "relu",
"paddle.nn.ReLU6": "relu", "paddle.nn.ReLU6": "relu",
"paddle.nn.Softmax": "softmax", "paddle.nn.Softmax": "softmax",
"paddle.nn.Softplus": "softplus", "paddle.nn.Softplus": "softplus",
"paddle.nn.Tanh": "tanh", "paddle.nn.Tanh": "tanh",
"paddle.nn.AvgPool2D": "avgpool", "paddle.nn.AvgPool2D": "avgpool",
"paddle.nn.MaxPool2D": "maxpool", "paddle.nn.MaxPool2D": "maxpool",
"paddle.nn.Pad1D": "pad1d", "paddle.nn.Pad1D": "pad1d",
"paddle.nn.Pad2D": "pad2d", "paddle.nn.Pad2D": "pad2d",
"paddle.nn.Pad3D": "pad3d", "paddle.nn.Pad3D": "pad3d",
"paddle.nn.Dropout": "dropout", "paddle.nn.Dropout": "dropout",
"paddle.nn.GELU": "gelu", "paddle.nn.GELU": "gelu",
"paddle.nn.Hardtanh": "tanh", "paddle.nn.Hardtanh": "tanh",
"paddle.nn.LeakyReLU": "leakly_relu"} "paddle.nn.LeakyReLU": "leakly_relu"
}
NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:10] NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:10]
def rename_layers(layers, param_tree=None, is_rename_module=False): def rename_layers(layers, param_tree=None, is_rename_module=False):
""" 对子模块的输入输出等进行重命名。 """ 对子模块的输入输出等进行重命名。
""" """
...@@ -58,6 +60,7 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -58,6 +60,7 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
new_names = list() new_names = list()
for kernel in NN_KERNEL_NAME.keys(): for kernel in NN_KERNEL_NAME.keys():
nn_count_dict[kernel] = 0 nn_count_dict[kernel] = 0
def rename_sub_layers(sub_layers, count, is_block=False): def rename_sub_layers(sub_layers, count, is_block=False):
for layer_id, layer in sub_layers.items(): for layer_id, layer in sub_layers.items():
# 对输入重命名 # 对输入重命名
...@@ -69,10 +72,9 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -69,10 +72,9 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
count += 1 count += 1
layer.inputs[input_k] = new_name layer.inputs[input_k] = new_name
name_dict[input_v] = new_name name_dict[input_v] = new_name
# 对block重命名 # 对block重命名
for block in layer.blocks: for block in layer.blocks:
count = rename_sub_layers(block.layers, count = rename_sub_layers(block.layers, count, is_block=True)
count, is_block=True)
# 对输出重命名 # 对输出重命名
if len(layer.outputs) == 0 and not is_block: if len(layer.outputs) == 0 and not is_block:
new_names.append("layer_id/{}".format(layer_id)) new_names.append("layer_id/{}".format(layer_id))
...@@ -83,9 +85,10 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -83,9 +85,10 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
new_names.append(name_dict[output_v]) new_names.append(name_dict[output_v])
else: else:
if i == 0 and layer.kernel in NN_KERNEL_NAME.keys(): if i == 0 and layer.kernel in NN_KERNEL_NAME.keys():
new_name = NN_KERNEL_NAME[layer.kernel] + str(nn_count_dict[layer.kernel]) new_name = NN_KERNEL_NAME[layer.kernel] + str(
param_node = PamareterNode(old_name=layer.outputs[0], nn_count_dict[layer.kernel])
new_name=new_name) param_node = PamareterNode(
old_name=layer.outputs[0], new_name=new_name)
nn_param_nodes.append(param_node) nn_param_nodes.append(param_node)
if param_tree is not None: if param_tree is not None:
param_tree.add_node(param_node) param_tree.add_node(param_node)
...@@ -94,7 +97,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -94,7 +97,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
elif i == 0 and layer.kernel == "module": elif i == 0 and layer.kernel == "module":
if is_rename_module: if is_rename_module:
if param_tree is not None: if param_tree is not None:
param_node = param_tree.get_node(layer.outputs[0]) param_node = param_tree.get_node(layer.outputs[
0])
nn_param_nodes.append(param_node) nn_param_nodes.append(param_node)
param_node.new_name = layer.outputs[0] param_node.new_name = layer.outputs[0]
else: else:
...@@ -105,7 +109,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -105,7 +109,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
nn_count_dict[old_name] += 1 nn_count_dict[old_name] += 1
new_name = old_name + str(nn_count_dict[old_name]) new_name = old_name + str(nn_count_dict[old_name])
if param_tree is not None: if param_tree is not None:
param_node = param_tree.get_node(layer.outputs[0]) param_node = param_tree.get_node(layer.outputs[
0])
nn_param_nodes.append(param_node) nn_param_nodes.append(param_node)
param_node.new_name = new_name param_node.new_name = new_name
layer.outputs[0] = new_name layer.outputs[0] = new_name
...@@ -116,8 +121,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -116,8 +121,8 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
layer.outputs[i] = new_name layer.outputs[i] = new_name
name_dict[output_v] = new_name name_dict[output_v] = new_name
if layer.kernel == "self.create_parameter": if layer.kernel == "self.create_parameter":
param_node = PamareterNode(old_name=old_name, param_node = PamareterNode(
new_name=new_name) old_name=old_name, new_name=new_name)
nn_param_nodes.append(param_node) nn_param_nodes.append(param_node)
if param_tree is not None: if param_tree is not None:
param_tree.add_node(param_node) param_tree.add_node(param_node)
...@@ -129,6 +134,7 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -129,6 +134,7 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
and attr_v in name_dict: and attr_v in name_dict:
layer.attrs[attr_k] = name_dict[attr_v] layer.attrs[attr_k] = name_dict[attr_v]
return count return count
rename_sub_layers(layers_cp, count) rename_sub_layers(layers_cp, count)
return layers_cp, nn_param_nodes, new_names return layers_cp, nn_param_nodes, new_names
...@@ -152,22 +158,24 @@ def _update_attrs(layer, different_attrs): ...@@ -152,22 +158,24 @@ def _update_attrs(layer, different_attrs):
common_attrs.update(special_attrs) common_attrs.update(special_attrs)
layer.attrs = common_attrs layer.attrs = common_attrs
def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
""" 根据sub_layers生成对应的Module代码。 """ 根据sub_layers生成对应的Module代码。
Args: Args:
graph (x2paddle.core.program.PaddleGraph): 整个Paddle图。 graph (x2paddle.core.program.PaddleGraph): 整个Paddle图。
sub_layers (dict): 子图的id和其对应layer组成的字典。 sub_layers (dict): 子图的id和其对应layer组成的字典。
sub_layers_name (str): 子图的名字。 sub_layers_name (str): 子图的名字。
different_attrs (dict/list): 属性字典/列表,这些属性表明在被调用时赋予不同值。 different_attrs (dict/list): 属性字典/列表,这些属性表明在被调用时赋予不同值。
""" """
def gen_codes(code_list, indent=0): def gen_codes(code_list, indent=0):
""" 根据code_list生成代码段。 """ 根据code_list生成代码段。
Args: Args:
code_list (list): 代码行组成的list。 code_list (list): 代码行组成的list。
indent (int): 每行空格的数量。 indent (int): 每行空格的数量。
Returns: Returns:
str: 代码段。 str: 代码段。
""" """
...@@ -179,10 +187,11 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -179,10 +187,11 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
else: else:
codes.append(indent_blank + code_line + '\n') codes.append(indent_blank + code_line + '\n')
return codes return codes
def gen_head(inputs, different_attrs): def gen_head(inputs, different_attrs):
# 生成Layer的头部代码 # 生成Layer的头部代码
head = gen_codes(["class {}(paddle.nn.Layer):".format(sub_layers_name)], indent=0) head = gen_codes(
["class {}(paddle.nn.Layer):".format(sub_layers_name)], indent=0)
# 生成init函数的头部代码 # 生成init函数的头部代码
diff_str_list = list() diff_str_list = list()
if isinstance(different_attrs, dict): if isinstance(different_attrs, dict):
...@@ -199,8 +208,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -199,8 +208,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
forward_func_head = \ forward_func_head = \
gen_codes(["def forward(self, {}):".format(input_data_name)], indent=1) gen_codes(["def forward(self, {}):".format(input_data_name)], indent=1)
return head, init_func_head, forward_func_head return head, init_func_head, forward_func_head
init_func = [] init_func = []
forward_func = [] forward_func = []
cur_outputs = list() cur_outputs = list()
...@@ -211,7 +219,9 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -211,7 +219,9 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
for layer_id, layer in sub_layers.items(): for layer_id, layer in sub_layers.items():
if layer_id not in graph.edges_out: if layer_id not in graph.edges_out:
for index, output_name in enumerate(layer.outputs): for index, output_name in enumerate(layer.outputs):
if layer.kernel.startswith("paddle.nn") and index == 0: if layer.kernel.startswith(
"paddle.nn"
) and index == 0 and "functional" not in layer.kernel:
continue continue
if not output_name.startswith("x") or output_name in outputs \ if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert": or layer.kernel == "prim.assert":
...@@ -225,7 +235,9 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -225,7 +235,9 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
for out_layer_id in graph.edges_out[layer_id]: for out_layer_id in graph.edges_out[layer_id]:
if out_layer_id not in sub_layers: if out_layer_id not in sub_layers:
for index, output_name in enumerate(layer.outputs): for index, output_name in enumerate(layer.outputs):
if layer.kernel.startswith("paddle.nn") and index == 0 and "functional" not in layer.kernel: if layer.kernel.startswith(
"paddle.nn"
) and index == 0 and "functional" not in layer.kernel:
continue continue
if not output_name.startswith("x") or output_name in outputs \ if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert": or layer.kernel == "prim.assert":
...@@ -263,17 +275,18 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -263,17 +275,18 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
init_func.extend(gen_codes([line], indent=2)) init_func.extend(gen_codes([line], indent=2))
if len(layer.outputs) == 1: if len(layer.outputs) == 1:
line = layer.outputs[0] line = layer.outputs[0]
elif len(layer.outputs) == 2: elif len(layer.outputs) == 2:
line = layer.outputs[1] line = layer.outputs[1]
else: else:
if layer.kernel == "paddle.nn.LSTM": if layer.kernel == "paddle.nn.LSTM":
line = "{}, ({})".format(layer.outputs[1], ', '.join(layer.outputs[-2:])) line = "{}, ({})".format(layer.outputs[1],
', '.join(layer.outputs[-2:]))
else: else:
line = ','.join(layer.outputs[1:]) line = ','.join(layer.outputs[1:])
line += " = self.{}(".format(layer.outputs[0]) line += " = self.{}(".format(layer.outputs[0])
for k, v in layer.inputs.items(): for k, v in layer.inputs.items():
if v not in cur_outputs and v not in inputs: if v not in cur_outputs and v not in inputs:
...@@ -299,15 +312,17 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -299,15 +312,17 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
indent=2, indent=2,
init_func=init_func, init_func=init_func,
forward_func=forward_func, forward_func=forward_func,
layer_id=layer_id, layer_id=layer_id,
different_attrs=list(different_attrs.keys()) if isinstance(different_attrs, dict) else different_attrs) different_attrs=list(different_attrs.keys())
if isinstance(different_attrs, dict) else different_attrs)
cur_outputs.extend(layer.outputs) cur_outputs.extend(layer.outputs)
else: else:
raise Exception( raise Exception(
"The kind {} in paddle model is not supported yet.". "The kind {} in paddle model is not supported yet.".format(
format(layer.kernel)) layer.kernel))
elif layer.kernel == "module": elif layer.kernel == "module":
line = "self.{} = {}(".format(layer.outputs[0], layer.attrs["module"]) line = "self.{} = {}(".format(layer.outputs[0],
layer.attrs["module"])
layer.attrs.pop("module") layer.attrs.pop("module")
for k, v in layer.attrs.items(): for k, v in layer.attrs.items():
key_name = "{}_{}".format(layer.outputs[0], k) key_name = "{}_{}".format(layer.outputs[0], k)
...@@ -358,23 +373,31 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -358,23 +373,31 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
key_name = "{}_{}".format(layer.outputs[0], k) key_name = "{}_{}".format(layer.outputs[0], k)
if key_name in different_attrs: if key_name in different_attrs:
line += "{}=self.{}, ".format(k, key_name) line += "{}=self.{}, ".format(k, key_name)
init_func.extend(gen_codes(["self.{} = {}".format(key_name, key_name)], indent=2)) init_func.extend(
gen_codes(
["self.{} = {}".format(key_name, key_name)],
indent=2))
else: else:
line += "{}={}, ".format(k, v) line += "{}={}, ".format(k, v)
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
if layer.kernel == "self.create_parameter": if layer.kernel == "self.create_parameter":
init_func.extend(gen_codes(["self." + line], indent=2)) init_func.extend(gen_codes(["self." + line], indent=2))
forward_func.extend(gen_codes(["{} = self.{}".format(layer.outputs[0], forward_func.extend(
layer.outputs[0])], indent=2)) gen_codes(
[
"{} = self.{}".format(layer.outputs[0],
layer.outputs[0])
],
indent=2))
else: else:
forward_func.extend(gen_codes([line], indent=2)) forward_func.extend(gen_codes([line], indent=2))
cur_outputs.extend(layer.outputs) cur_outputs.extend(layer.outputs)
head, init_func_head, forward_func_head = gen_head(inputs, different_attrs) head, init_func_head, forward_func_head = gen_head(inputs, different_attrs)
output_data_name = ", ".join(outputs) output_data_name = ", ".join(outputs)
code_list = head + init_func_head + init_func + \ code_list = head + init_func_head + init_func + \
forward_func_head + forward_func + \ forward_func_head + forward_func + \
gen_codes(["return {}".format(output_data_name)], indent=2) gen_codes(["return {}".format(output_data_name)], indent=2)
code_str = "".join(code_list) code_str = "".join(code_list)
return code_str return code_str
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册