提交 0c993cd2 编写于 作者: S SunAhong1993

add remove default param

上级 d9f33f45
......@@ -929,7 +929,7 @@ def aten_constant_pad_nd(mapper, graph, node):
outputs=[inputs_name[0] + "_list"],
scope_name=scope_name)
block.add_layer(
"paddle.tensor.unsqueeze",
"paddle.unsqueeze",
inputs={"x": inputs_name[0],
"axis": inputs_name[0] + "_list"},
outputs=[inputs_name[0] + "_var"],
......@@ -941,7 +941,7 @@ def aten_constant_pad_nd(mapper, graph, node):
scope_name=scope_name,
**layer_attrs)
block.add_layer(
"paddle.tensor.squeeze",
"paddle.squeeze",
inputs={"x": output_name,
"axis": inputs_name[0] + "_list"},
outputs=[output_name],
......@@ -1703,7 +1703,7 @@ def aten_expand_as(mapper, graph, node):
outputs=[inputs_name[1] + "_type"],
scope_name=scope_name)
block.add_layer(
"fluid.layers.cast",
"paddle.cast",
inputs={"x": inputs_name[0]},
outputs=[inputs_name[0]],
scope_name=scope_name,
......@@ -1722,7 +1722,7 @@ def aten_expand_as(mapper, graph, node):
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"fluid.layers.cast",
"paddle.cast",
inputs={"x": layer_outputs[0]},
outputs=copy.deepcopy(layer_outputs),
scope_name=scope_name,
......@@ -4074,7 +4074,7 @@ def aten_squeeze(mapper, graph, node):
layer_inputs["axis"] = inputs_name[1]
current_inputs.append(inputs_name[1])
graph.add_layer(
"paddle.tensor.squeeze",
"paddle.squeeze",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
......
......@@ -67,7 +67,10 @@ class HierarchicalTree(Tree):
else:
len1 = len(min_scope_name.split("/"))
len2 = len(scope_name.split("/"))
if len1 > len2 and scope_name in self.scope_name_list:
if scope_name not in self.scope_name_list:
min_scope_name = scope_name
continue
if len1 > len2:
min_scope_name = scope_name
if min_scope_name == "":
self.create_node(tag=layer.id,
......
......@@ -14,6 +14,9 @@
# limitations under the License.
import copy
import yaml
import os.path as osp
import x2paddle
from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode
NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
......@@ -125,15 +128,34 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
rename_sub_layers(layers_cp, count)
return layers_cp, nn_param_nodes, new_names
def load_default_parameter():
path = x2paddle.__file__
path = path.replace("__init__.py", "core")
yaml_dir = osp.join(path, "paddle_default_parameter.yml")
with open(yaml_dir, "rb") as fr:
default_parameter = yaml.load(fr.read())
return default_parameter
def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
def is_abandon(default_parameter, layer_kernel, param_key, param_value):
if layer_kernel not in default_parameter:
return False
params = default_parameter[layer_kernel]
if param_key not in params:
return False
if params[param_key] == param_value:
return True
else:
return False
def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
""" 根据sub_layers生成对应的Module代码。
Args:
graph (x2paddle.core.program.PaddleGraph): 整个Paddle图。
sub_layers (dict): 子图的id和其对应layer组成的字典。
sub_layers_name (str): 子图的名字。
different_attrs (list): 属性列表,这些属性表明在被调用时赋予不同值。
different_attrs (dict/list): 属性字典/列表,这些属性表明在被调用时赋予不同值。
"""
def gen_codes(code_list, indent=0):
""" 根据code_list生成代码段。
......@@ -158,6 +180,12 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
# 生成Layer的头部代码
head = gen_codes(["class {}(paddle.nn.Layer):".format(sub_layers_name)], indent=0)
# 生成init函数的头部代码
diff_str_list = list()
if isinstance(different_attrs, dict):
for k, v in different_attrs.items():
diff_str_list.append("{}={}".format(k, v))
attrs_str = ", ".join(diff_str_list)
else:
attrs_str = ", ".join(different_attrs)
init_func_head = \
gen_codes(["def __init__(self, {}):".format(attrs_str)], indent=1) + \
......@@ -213,6 +241,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
if is_set_item:
outputs.append(layer.outputs[0])
no_output_count = 0
default_parameter = load_default_parameter()
for i, (layer_id, layer) in enumerate(sub_layers.items()):
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
layer.kernel.startswith("custom_layer"):
......@@ -226,6 +255,8 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
if key_name in different_attrs:
line += "{}={}, ".format(k, key_name)
else:
if is_abandon(default_parameter, layer.kernel, k, v):
continue
line += "{}={}, ".format(k, v)
line = line.strip(", ")
line += ")"
......@@ -267,7 +298,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
init_func=init_func,
forward_func=forward_func,
layer_id=layer_id,
different_attrs=different_attrs)
different_attrs=list(different_attrs.keys()) if isinstance(different_attrs, dict) else different_attrs)
cur_outputs.extend(layer.outputs)
else:
raise Exception(
......@@ -327,6 +358,8 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
line += "{}=self.{}, ".format(k, key_name)
init_func.extend(gen_codes(["self.{} = {}".format(key_name, key_name)], indent=2))
else:
if is_abandon(default_parameter, layer.kernel, k, v):
continue
line += "{}={}, ".format(k, v)
line = line.strip(", ")
line += ")"
......
......@@ -179,16 +179,27 @@ class ModuleGraph(object):
def analyze_attrs_table(self, attrs_table):
""" 分析属性表格,哪些属性取值不一致。
"""
diff_attrs_column = list()
diff_attrs_column = dict()
for column in list(attrs_table.columns):
elements = list(attrs_table.get(column))
base = elements[0]
for element in elements[1:]:
if isinstance(base, str) and "'" not in base:
break
if element != base:
diff_attrs_column.append(column)
elements_list = list()
count_list = list()
for element in elements:
if isinstance(element, str) and "'" not in element:
break
if element not in elements_list:
count_list.append(1)
elements_list.append(element)
else:
index = elements_list.index(element)
count_list[index] += 1
if len(elements_list) > 1:
max_ct = 0
for k, v in zip(elements_list, count_list):
if v > max_ct:
max_ele = k
max_ct = v
diff_attrs_column[column] = max_ele
return diff_attrs_column
def analyze_graph(self, sub_layers_list):
......@@ -258,8 +269,10 @@ class ModuleGraph(object):
outputs = ["{}_{}".format(mn, index)] + outputs
node_name = "{}_{}".format(module_name, index)
diff_attrs = dict()
for column in diff_attrs_column:
diff_attrs[column] = attrs_table.get(column).loc[node_name]
for column, element in diff_attrs_column.items():
current_element = attrs_table.get(column).loc[node_name]
if current_element != element:
diff_attrs[column] = current_element
new_layer = PaddleLayer(id=list(sub_layers.keys())[-1],
kernel="module",
inputs=inputs_dict,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册