提交 0c993cd2 编写于 作者: S SunAhong1993

add remove default param

上级 d9f33f45
...@@ -929,7 +929,7 @@ def aten_constant_pad_nd(mapper, graph, node): ...@@ -929,7 +929,7 @@ def aten_constant_pad_nd(mapper, graph, node):
outputs=[inputs_name[0] + "_list"], outputs=[inputs_name[0] + "_list"],
scope_name=scope_name) scope_name=scope_name)
block.add_layer( block.add_layer(
"paddle.tensor.unsqueeze", "paddle.unsqueeze",
inputs={"x": inputs_name[0], inputs={"x": inputs_name[0],
"axis": inputs_name[0] + "_list"}, "axis": inputs_name[0] + "_list"},
outputs=[inputs_name[0] + "_var"], outputs=[inputs_name[0] + "_var"],
...@@ -941,7 +941,7 @@ def aten_constant_pad_nd(mapper, graph, node): ...@@ -941,7 +941,7 @@ def aten_constant_pad_nd(mapper, graph, node):
scope_name=scope_name, scope_name=scope_name,
**layer_attrs) **layer_attrs)
block.add_layer( block.add_layer(
"paddle.tensor.squeeze", "paddle.squeeze",
inputs={"x": output_name, inputs={"x": output_name,
"axis": inputs_name[0] + "_list"}, "axis": inputs_name[0] + "_list"},
outputs=[output_name], outputs=[output_name],
...@@ -1703,7 +1703,7 @@ def aten_expand_as(mapper, graph, node): ...@@ -1703,7 +1703,7 @@ def aten_expand_as(mapper, graph, node):
outputs=[inputs_name[1] + "_type"], outputs=[inputs_name[1] + "_type"],
scope_name=scope_name) scope_name=scope_name)
block.add_layer( block.add_layer(
"fluid.layers.cast", "paddle.cast",
inputs={"x": inputs_name[0]}, inputs={"x": inputs_name[0]},
outputs=[inputs_name[0]], outputs=[inputs_name[0]],
scope_name=scope_name, scope_name=scope_name,
...@@ -1722,7 +1722,7 @@ def aten_expand_as(mapper, graph, node): ...@@ -1722,7 +1722,7 @@ def aten_expand_as(mapper, graph, node):
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"fluid.layers.cast", "paddle.cast",
inputs={"x": layer_outputs[0]}, inputs={"x": layer_outputs[0]},
outputs=copy.deepcopy(layer_outputs), outputs=copy.deepcopy(layer_outputs),
scope_name=scope_name, scope_name=scope_name,
...@@ -4074,7 +4074,7 @@ def aten_squeeze(mapper, graph, node): ...@@ -4074,7 +4074,7 @@ def aten_squeeze(mapper, graph, node):
layer_inputs["axis"] = inputs_name[1] layer_inputs["axis"] = inputs_name[1]
current_inputs.append(inputs_name[1]) current_inputs.append(inputs_name[1])
graph.add_layer( graph.add_layer(
"paddle.tensor.squeeze", "paddle.squeeze",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
scope_name=scope_name, scope_name=scope_name,
......
...@@ -67,7 +67,10 @@ class HierarchicalTree(Tree): ...@@ -67,7 +67,10 @@ class HierarchicalTree(Tree):
else: else:
len1 = len(min_scope_name.split("/")) len1 = len(min_scope_name.split("/"))
len2 = len(scope_name.split("/")) len2 = len(scope_name.split("/"))
if len1 > len2 and scope_name in self.scope_name_list: if scope_name not in self.scope_name_list:
min_scope_name = scope_name
continue
if len1 > len2:
min_scope_name = scope_name min_scope_name = scope_name
if min_scope_name == "": if min_scope_name == "":
self.create_node(tag=layer.id, self.create_node(tag=layer.id,
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
import copy import copy
import yaml
import os.path as osp
import x2paddle
from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode
NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn", NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
...@@ -125,15 +128,34 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -125,15 +128,34 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
rename_sub_layers(layers_cp, count) rename_sub_layers(layers_cp, count)
return layers_cp, nn_param_nodes, new_names return layers_cp, nn_param_nodes, new_names
def load_default_parameter():
    """Load the table of default parameter values shipped with x2paddle.

    The table is read from ``core/paddle_default_parameter.yml`` inside the
    directory that contains the ``x2paddle`` package's ``__file__``.

    Returns:
        The parsed YAML document. Callers in this file index it as
        ``default_parameter[layer_kernel][param_key]``.

    Raises:
        OSError: If the YAML file cannot be opened.
        yaml.YAMLError: If the file is not valid (safe) YAML.
    """
    # Use dirname() rather than replacing the literal "__init__.py" in the
    # path: the string replace yields a wrong directory whenever __file__
    # carries a different suffix (for example "__init__.pyc").
    package_dir = osp.dirname(x2paddle.__file__)
    yaml_path = osp.join(package_dir, "core", "paddle_default_parameter.yml")
    with open(yaml_path, "rb") as fr:
        # safe_load only builds plain Python data. yaml.load() without an
        # explicit Loader is deprecated/unsupported in newer PyYAML releases
        # and can construct arbitrary objects.
        default_parameter = yaml.safe_load(fr)
    return default_parameter
def is_abandon(default_parameter, layer_kernel, param_key, param_value):
    """Tell whether a layer argument merely repeats its recorded default.

    The code generator in this file skips emitting ``param_key=param_value``
    when this returns True.

    Args:
        default_parameter: Mapping from layer kernel name to a mapping of
            parameter name -> default value.
        layer_kernel: Kernel name of the layer being emitted.
        param_key: Name of the argument under consideration.
        param_value: Value the argument would be emitted with.

    Returns:
        bool: True only if the kernel and the parameter are both present in
        ``default_parameter`` and the recorded default equals ``param_value``.
    """
    if layer_kernel in default_parameter:
        kernel_defaults = default_parameter[layer_kernel]
        if param_key in kernel_defaults:
            return bool(kernel_defaults[param_key] == param_value)
    # Unknown kernel or unknown parameter: keep the argument.
    return False
def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
""" 根据sub_layers生成对应的Module代码。 """ 根据sub_layers生成对应的Module代码。
Args: Args:
graph (x2paddle.core.program.PaddleGraph): 整个Paddle图。 graph (x2paddle.core.program.PaddleGraph): 整个Paddle图。
sub_layers (dict): 子图的id和其对应layer组成的字典。 sub_layers (dict): 子图的id和其对应layer组成的字典。
sub_layers_name (str): 子图的名字。 sub_layers_name (str): 子图的名字。
different_attrs (list): 属性列表,这些属性表明在被调用时赋予不同值。 different_attrs (dict/list): 属性字典/列表,这些属性表明在被调用时赋予不同值。
""" """
def gen_codes(code_list, indent=0): def gen_codes(code_list, indent=0):
""" 根据code_list生成代码段。 """ 根据code_list生成代码段。
...@@ -158,6 +180,12 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): ...@@ -158,6 +180,12 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
# 生成Layer的头部代码 # 生成Layer的头部代码
head = gen_codes(["class {}(paddle.nn.Layer):".format(sub_layers_name)], indent=0) head = gen_codes(["class {}(paddle.nn.Layer):".format(sub_layers_name)], indent=0)
# 生成init函数的头部代码 # 生成init函数的头部代码
diff_str_list = list()
if isinstance(different_attrs, dict):
for k, v in different_attrs.items():
diff_str_list.append("{}={}".format(k, v))
attrs_str = ", ".join(diff_str_list)
else:
attrs_str = ", ".join(different_attrs) attrs_str = ", ".join(different_attrs)
init_func_head = \ init_func_head = \
gen_codes(["def __init__(self, {}):".format(attrs_str)], indent=1) + \ gen_codes(["def __init__(self, {}):".format(attrs_str)], indent=1) + \
...@@ -213,6 +241,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): ...@@ -213,6 +241,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
if is_set_item: if is_set_item:
outputs.append(layer.outputs[0]) outputs.append(layer.outputs[0])
no_output_count = 0 no_output_count = 0
default_parameter = load_default_parameter()
for i, (layer_id, layer) in enumerate(sub_layers.items()): for i, (layer_id, layer) in enumerate(sub_layers.items()):
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \ if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
layer.kernel.startswith("custom_layer"): layer.kernel.startswith("custom_layer"):
...@@ -226,6 +255,8 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): ...@@ -226,6 +255,8 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
if key_name in different_attrs: if key_name in different_attrs:
line += "{}={}, ".format(k, key_name) line += "{}={}, ".format(k, key_name)
else: else:
if is_abandon(default_parameter, layer.kernel, k, v):
continue
line += "{}={}, ".format(k, v) line += "{}={}, ".format(k, v)
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
...@@ -267,7 +298,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): ...@@ -267,7 +298,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
init_func=init_func, init_func=init_func,
forward_func=forward_func, forward_func=forward_func,
layer_id=layer_id, layer_id=layer_id,
different_attrs=different_attrs) different_attrs=list(different_attrs.keys()) if isinstance(different_attrs, dict) else different_attrs)
cur_outputs.extend(layer.outputs) cur_outputs.extend(layer.outputs)
else: else:
raise Exception( raise Exception(
...@@ -327,6 +358,8 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): ...@@ -327,6 +358,8 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
line += "{}=self.{}, ".format(k, key_name) line += "{}=self.{}, ".format(k, key_name)
init_func.extend(gen_codes(["self.{} = {}".format(key_name, key_name)], indent=2)) init_func.extend(gen_codes(["self.{} = {}".format(key_name, key_name)], indent=2))
else: else:
if is_abandon(default_parameter, layer.kernel, k, v):
continue
line += "{}={}, ".format(k, v) line += "{}={}, ".format(k, v)
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
......
...@@ -179,16 +179,27 @@ class ModuleGraph(object): ...@@ -179,16 +179,27 @@ class ModuleGraph(object):
def analyze_attrs_table(self, attrs_table):
    """Find the attribute columns whose values are not all the same.

    Args:
        attrs_table: Table object exposing ``columns`` and ``get(column)``
            (presumably a pandas DataFrame -- confirm against the caller).

    Returns:
        dict: Maps each column that holds more than one distinct value to
        the value that occurs most often in it (the earliest such value on
        a tie).
    """
    majority_by_column = {}
    for column in list(attrs_table.columns):
        distinct_values = []
        occurrences = []
        for value in list(attrs_table.get(column)):
            # A str without a single quote ends the scan of this column.
            # NOTE(review): presumably such a value names a variable rather
            # than a quoted literal -- confirm.
            if isinstance(value, str) and "'" not in value:
                break
            if value in distinct_values:
                occurrences[distinct_values.index(value)] += 1
            else:
                distinct_values.append(value)
                occurrences.append(1)
        if len(distinct_values) <= 1:
            continue
        # max() returns the first index holding the highest count, so ties
        # resolve to the value that was seen first.
        top = max(range(len(occurrences)), key=occurrences.__getitem__)
        majority_by_column[column] = distinct_values[top]
    return majority_by_column
def analyze_graph(self, sub_layers_list): def analyze_graph(self, sub_layers_list):
...@@ -258,8 +269,10 @@ class ModuleGraph(object): ...@@ -258,8 +269,10 @@ class ModuleGraph(object):
outputs = ["{}_{}".format(mn, index)] + outputs outputs = ["{}_{}".format(mn, index)] + outputs
node_name = "{}_{}".format(module_name, index) node_name = "{}_{}".format(module_name, index)
diff_attrs = dict() diff_attrs = dict()
for column in diff_attrs_column: for column, element in diff_attrs_column.items():
diff_attrs[column] = attrs_table.get(column).loc[node_name] current_element = attrs_table.get(column).loc[node_name]
if current_element != element:
diff_attrs[column] = current_element
new_layer = PaddleLayer(id=list(sub_layers.keys())[-1], new_layer = PaddleLayer(id=list(sub_layers.keys())[-1],
kernel="module", kernel="module",
inputs=inputs_dict, inputs=inputs_dict,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册