From b4eceac8cd83d2095bb8ddca78ec3f93d90105c8 Mon Sep 17 00:00:00 2001
From: wjj19950828
Date: Tue, 25 Jan 2022 15:32:11 +0800
Subject: [PATCH] Fixed FusedBatchNorm

---
 x2paddle/op_mapper/tf2paddle/tf_op_mapper.py  |   4 +-
 .../pytorch_code_optimizer/module_graph.py    | 158 ++++++++++--------
 2 files changed, 91 insertions(+), 71 deletions(-)

diff --git a/x2paddle/op_mapper/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/tf2paddle/tf_op_mapper.py
index c913bba..e163857 100644
--- a/x2paddle/op_mapper/tf2paddle/tf_op_mapper.py
+++ b/x2paddle/op_mapper/tf2paddle/tf_op_mapper.py
@@ -563,7 +563,6 @@ class TFOpMapper():
             n, h, w, c = input.out_shapes[0]
         else:
             n, c, h, w = input.out_shapes[0]
-
         self.params["{}_{}".format(node.name, gamma.name)] = self.params[
             gamma.name]
         self.params["{}_{}".format(node.name, beta.name)] = self.params[
@@ -584,7 +583,8 @@ class TFOpMapper():
                 moving_mean.name)),
             moving_variance_name=string("{}_{}".format(node.name,
                                                        moving_var.name)),
-            is_test=True)
+            is_test=True,
+            trainable_statistics=node.get_attr("is_training"))
 
         if data_format == "NHWC":
             self.paddle_graph.add_layer(
diff --git a/x2paddle/optimizer/pytorch_code_optimizer/module_graph.py b/x2paddle/optimizer/pytorch_code_optimizer/module_graph.py
index 669e305..ac98da8 100644
--- a/x2paddle/optimizer/pytorch_code_optimizer/module_graph.py
+++ b/x2paddle/optimizer/pytorch_code_optimizer/module_graph.py
@@ -21,27 +21,28 @@ from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import construct_
 from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers
 from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
 
-
 NoModuleStart = ["paddle.nn.ReLU"]
 
+
 class Apriori(object):
     """ Mine frequent subgraphs with the Apriori algorithm.
     1. Build the frequent 1-itemsets.
     2. Mine the frequent k-itemsets.
     3. The final k-itemsets whose subgraphs meet the minimum node count form the set GS.
-    
+
     Args:
         min_support (int): Minimum number of times a subgraph has to occur.
     """
+
     def __init__(self, min_support):
-        self.min_support = min_support
-        
+        self.min_support = min_support
+
     def is_match(self, item, sublayers):
         for i in range(len(item)):
             if len(sublayers) <= i or item[i] != sublayers[i].kernel:
                 return False
         return True
-    
+
     def create_C1(self):
         # Build the candidate 1-itemsets
         C1 = list()
@@ -49,9 +50,9 @@ class Apriori(object):
             if layer.kernel == "paddle.to_tensor" or \
                layer.kernel == "prim.if" or \
                layer.kernel == "prim.loop": #or \
-#                layer.kernel == "prim.list" or \
-#                layer.kernel == "prim.tuple" or \
-#                layer.kernel == "prim.dict_construct":
+               #                layer.kernel == "prim.list" or \
+               #                layer.kernel == "prim.tuple" or \
+               #                layer.kernel == "prim.dict_construct":
                 continue
             if self.pd_graph.edges_in.get(layer_id, 0) == 0 and \
                self.pd_graph.edges_out.get(layer_id, 0) == 0:
                 continue
             if [layer.kernel] not in C1:
                 C1.append([layer.kernel])
         return C1
-    
+
     def create_Ck(self, Lk_last, C1):
         # Build the candidate k-itemsets
         Ck = list()
@@ -71,7 +72,7 @@ class Apriori(object):
                 continue
             Ck.append(new_item)
         return Ck
-    
+
     def generate_Lk_by_Ck(self, Ck):
         # Generate the frequent k-itemsets
         Lk = list()
@@ -82,9 +83,9 @@ class Apriori(object):
                 if self.is_match(item, sublayers):
                     count += 1
             if count >= self.min_support:
-                Lk.append(item) 
+                Lk.append(item)
         return Lk
-    
+
     def run(self, graph):
         self.pd_graph = graph
         self.layers = graph.layers
@@ -97,14 +98,15 @@ class Apriori(object):
             Lk = self.generate_Lk_by_Ck(Ck)
             itemset.extend(Lk)
         return itemset
-    
+
 
 class DP(object):
-    """ Use dynamic dynamic programming to find the combination that yields the shortest code.
+    """ Use dynamic programming to find the combination that yields the shortest code.
     """
+
     def __init__(self, combination_itemset):
         self.combination_itemset = combination_itemset
-    
+
     def get_combination_id(self, combination, layers):
         combination_id = list()
         for layer_obj in combination:
@@ -117,7 +119,7 @@ class DP(object):
             else:
                 combination_id.append(-1)
         return combination_id
-    
+
     def run(self, graph):
         layers = graph.layers
         layer_combination_list = list()
@@ -135,7 +137,7 @@ class DP(object):
                 current_layer_id = list(layers.keys())[j]
                 current_layer = list(layers.values())[j]
                 current_itemset.insert(0, current_layer_id)
-                kernel_itemset.insert(0, current_layer.kernel) 
+                kernel_itemset.insert(0, current_layer.kernel)
                 if kernel_itemset in self.combination_itemset:
                     current_count = len(layer_combination_list[j - 1])
                     all_count = current_count + 1
@@ -145,7 +147,8 @@ class DP(object):
                         if j - 1 < 0:
                             last_itemset = list()
                         else:
-                            last_itemset = copy.deepcopy(layer_combination_list[j - 1])
+                            last_itemset = copy.deepcopy(layer_combination_list[
+                                j - 1])
                 else:
                     if j == prefix_ids[0]:
                         min_count = len(layer_combination_list[j]) + 1
@@ -158,24 +161,25 @@ class DP(object):
         final_combination = layer_combination_list[-1]
         combination_id = self.get_combination_id(final_combination, layers)
         return final_combination, combination_id
-        
-        
+
+
 class ModuleGraph(object):
     """ Update the PaddleGraph and generate code.
     """
+
     def __init__(self, graph):
         self.pd_graph = graph
         self.global_layers = graph.get_global_layers()
         self.codes = list()
         self.param_tree = PamareterTree()
-    
+
     def get_updation_information(self):
         aprior = Apriori(3)
         combination_itemset = aprior.run(self.pd_graph)
         dp = DP(combination_itemset)
         combination, combination_id = dp.run(self.pd_graph)
         return combination, combination_id
-    
+
     def analyze_attrs_table(self, attrs_table):
         """ Analyze the attribute table to find which attributes take inconsistent values.
         """
@@ -196,12 +200,12 @@ class ModuleGraph(object):
             if len(elements_list) > 1:
                 max_ct = 0
                 for k, v in zip(elements_list, count_list):
-                    if v > max_ct and str(k) != "nan" :
+                    if v > max_ct and str(k) != "nan":
                         max_ele = k
                         max_ct = v
                 diff_attrs_column[column] = max_ele
         return diff_attrs_column
-    
+
     def analyze_graph(self, sub_layers_list):
         def is_same(sub_layers1, sub_layers2, id1, id2):
             inputs1, outputs1 = ipt_opt_list[id1]
@@ -214,24 +218,34 @@ class ModuleGraph(object):
                 layer_id2 = layer_id_list2[i]
                 if layer_id2 not in self.pd_graph.edges_in:
                     return False
-                if len(self.pd_graph.edges_in[layer_id1]) != len(self.pd_graph.edges_in[layer_id2]):
+                if len(self.pd_graph.edges_in[layer_id1]) != len(
+                        self.pd_graph.edges_in[layer_id2]):
                     return False
-                for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[layer_id1]):
+                for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[
+                        layer_id1]):
                     ipt_layer_id2 = self.pd_graph.edges_in[layer_id2][j]
-                    if (ipt_layer_id1 in layer_id_list1) ^ (ipt_layer_id2 in layer_id_list2):
+                    if (ipt_layer_id1 in layer_id_list1) ^ (
+                            ipt_layer_id2 in layer_id_list2):
                         return False
-                if (layer_id1 in self.pd_graph.edges_out) ^ (layer_id2 in self.pd_graph.edges_out):
+                if (layer_id1 in self.pd_graph.edges_out) ^ (
+                        layer_id2 in self.pd_graph.edges_out):
                     return False
-                if (layer_id1 in self.pd_graph.edges_out) and (layer_id2 in self.pd_graph.edges_out):
+                if (layer_id1 in self.pd_graph.edges_out) and (
+                        layer_id2 in self.pd_graph.edges_out):
                     if (len(self.pd_graph.edges_out[layer_id1]) > 1 and len(self.pd_graph.edges_out[layer_id2]) == 1) or \
                         (len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) > 1):
                         return False
-                    for j, opt_layer_id1 in enumerate(self.pd_graph.edges_out[layer_id1]):
-                        if len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) == 1:
-                            opt_layer_id2 = self.pd_graph.edges_out[layer_id2][j]
-                            if (opt_layer_id1 in layer_id_list1) ^ (opt_layer_id2 in layer_id_list2):
+                    for j, opt_layer_id1 in enumerate(self.pd_graph.edges_out[
+                            layer_id1]):
+                        if len(self.pd_graph.edges_out[layer_id1]) == 1 and len(
+                                self.pd_graph.edges_out[layer_id2]) == 1:
+                            opt_layer_id2 = self.pd_graph.edges_out[layer_id2][
+                                j]
+                            if (opt_layer_id1 in layer_id_list1) ^ (
+                                    opt_layer_id2 in layer_id_list2):
                                 return False
             return True
+
         sub_layers_list_list = list()
         id_list = list()
         ipt_opt_list = list()
@@ -251,16 +265,16 @@ class ModuleGraph(object):
                     sub_layers_list_list[j + 1].append(sub_layer)
                     id_list.append(i)
         return sub_layers_list_list
-        
-        
+
     def merge_node(self, sub_layers_list, attrs_table, module_name):
         sub_layers = sub_layers_list[0]
         diff_attrs_column = self.analyze_attrs_table(attrs_table)
         sub_layers, _, _ = rename_layers(sub_layers)
-        code_str = gen_layer_code(self.pd_graph,
-                                  sub_layers,
-                                  module_name,
-                                  different_attrs=diff_attrs_column)
+        code_str = gen_layer_code(
+            self.pd_graph,
+            sub_layers,
+            module_name,
+            different_attrs=diff_attrs_column)
         self.codes.append(code_str)
         for index, sub_layers in enumerate(sub_layers_list):
             inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers)
@@ -270,18 +284,19 @@ class ModuleGraph(object):
             mn = module_name.lower()
             outputs = ["{}_{}".format(mn, index)] + outputs
             node_name = "{}_{}".format(module_name, index)
-            diff_attrs = dict() 
+            diff_attrs = dict()
             for column, element in diff_attrs_column.items():
                 current_element = attrs_table.get(column).loc[node_name]
                 if current_element != element:
                     diff_attrs[column] = current_element
-            new_layer = PaddleLayer(id=list(sub_layers.keys())[-1],
-                                    kernel="module",
-                                    inputs=inputs_dict,
-                                    outputs=outputs,
-                                    module=module_name,
-                                    **diff_attrs)
-            
+            new_layer = PaddleLayer(
+                id=list(sub_layers.keys())[-1],
+                kernel="module",
+                inputs=inputs_dict,
+                outputs=outputs,
+                module=module_name,
+                **diff_attrs)
+
             _, nn_param_nodes, _ = rename_layers(sub_layers, self.param_tree)
             param_node = PamareterNode(old_name=outputs[0])
             for node in nn_param_nodes:
@@ -297,7 +312,7 @@ class ModuleGraph(object):
                 self.pd_graph.layers.pop(layer_id)
 
         self.pd_graph.build()
-    
+
     def convert_subgraph_to_layer(self, combination, combination_id):
         combination_id_set = set(combination_id)
         for s in list(combination_id_set):
@@ -318,51 +333,57 @@ class ModuleGraph(object):
                 else:
                     real_module_name = module_name + "__{}".format(i)
                 if len(sub_layers_list) > 1:
-                    attrs_table = construct_attrs_table(sub_layers_list, module_name=real_module_name)
-                    self.merge_node(sub_layers_list, attrs_table, real_module_name)
-        layers, nn_param_nodes, _ = rename_layers(self.pd_graph.layers, self.param_tree, is_rename_module=True)
-        code_str = gen_layer_code(self.pd_graph,
-                                  layers,
-                                  self.pd_graph.name)
+                    attrs_table = construct_attrs_table(
+                        sub_layers_list, module_name=real_module_name)
+                    self.merge_node(sub_layers_list, attrs_table,
+                                    real_module_name)
+        layers, nn_param_nodes, _ = rename_layers(
+            self.pd_graph.layers, self.param_tree, is_rename_module=True)
+        code_str = gen_layer_code(self.pd_graph, layers, self.pd_graph.name)
         self.codes.append(code_str)
         param_node = PamareterNode(old_name="Module")
         for node in nn_param_nodes:
             param_node.add_child(node)
         self.param_tree.add_node(param_node)
-    
+
     def update_parameters(self):
         """ Update the parameters.
         """
         self.param_tree.traverse()
-        full_old_name_list = copy.deepcopy(list(self.pd_graph.parameters.keys()))
+        full_old_name_list = copy.deepcopy(
+            list(self.pd_graph.parameters.keys()))
         for old_name, new_name in self.param_tree.old2new.items():
             for full_old_name in full_old_name_list:
                 if full_old_name.startswith("{}.".format(old_name)):
-                    full_new_name = full_old_name.replace("{}.".format(old_name), "{}.".format(new_name))
+                    full_new_name = full_old_name.replace(
+                        "{}.".format(old_name), "{}.".format(new_name))
                     params = self.pd_graph.parameters.pop(full_old_name)
                     self.pd_graph.parameters[full_new_name] = params
                 if full_old_name == old_name:
                     full_new_name = full_old_name.replace(old_name, new_name)
                     params = self.pd_graph.parameters.pop(full_old_name)
                     self.pd_graph.parameters[full_new_name] = params
-    
+
     def save_source_files(self, save_dir):
         def gen_main_code():
             input_data_name = ', '.join(self.pd_graph.inputs)
             run_func_list = list()
             run_func_list.append("def main({}):".format(input_data_name))
-            run_func_list.append("    # There are {} inputs.".format(len(self.pd_graph.inputs_info)))
+            run_func_list.append("    # There are {} inputs.".format(
+                len(self.pd_graph.inputs_info)))
             for k, v in self.pd_graph.inputs_info.items():
-                run_func_list.append("    # {}: shape-{}, type-{}.".format(k, v[0], v[1]))
-            run_func_list.extend(
-                ["    paddle.disable_static()",
-                 "    params = paddle.load('{}')".format(osp.join(osp.abspath(save_dir), "model.pdparams")),
-                 "    model = {}()".format(self.pd_graph.name),
-                 "    model.set_dict(params)",
-                 "    model.eval()",
-                 "    out = model({})".format(input_data_name),
-                 "    return out"])
+                run_func_list.append("    # {}: shape-{}, type-{}.".format(k, v[
+                    0], v[1]))
+            run_func_list.extend([
+                "    paddle.disable_static()",
+                "    params = paddle.load('{}')".format(
+                    osp.join(osp.abspath(save_dir), "model.pdparams")),
+                "    model = {}()".format(self.pd_graph.name),
+                "    model.set_dict(params)", "    model.eval()",
+                "    out = model({})".format(input_data_name), "    return out"
+            ])
             return "\n".join(run_func_list)
+
         combination, combination_id = self.get_updation_information()
         self.convert_subgraph_to_layer(combination, combination_id)
         self.update_parameters()
@@ -382,4 +403,3 @@ class ModuleGraph(object):
         run_func = gen_main_code()
         f.write(run_func)
         f.close()
-        
\ No newline at end of file
-- 
GitLab
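
The functional change above is the tf_op_mapper.py hunk: the emitted batch-norm layer now forwards the TF FusedBatchNorm node's is_training attribute as trainable_statistics instead of always being built for fixed inference statistics. Below is a minimal sketch of the kind of layer the generated code would then construct, assuming the emitted kernel is paddle.nn.BatchNorm (the keyword set in the hunk matches that API); the channel count and parameter names are made up, not taken from the patch:

    # Illustrative only: hypothetical names and shapes, not from the patch.
    import paddle

    bn = paddle.nn.BatchNorm(
        num_channels=64,                       # hypothetical channel count
        param_attr="bn_gamma",                 # hypothetical parameter names
        bias_attr="bn_beta",
        moving_mean_name="bn_moving_mean",
        moving_variance_name="bn_moving_variance",
        is_test=True,
        trainable_statistics=True)             # forwarded from the node's is_training attribute

    x = paddle.randn([1, 64, 56, 56])          # NCHW, after the mapper's NHWC transpose
    print(bn(x).shape)

Before the patch the flag was always omitted, so converted graphs whose FusedBatchNorm nodes carry is_training=True lost that information; forwarding the attribute is the point of the fix.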
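The module_graph.py hunks are formatting-only, but the docstrings they touch describe what this optimizer does: mine kernel sequences that repeat at least min_support times (Apriori), then let DP pick the combination that yields the shortest generated code. A toy, self-contained illustration of the counting step follows; it uses brute-force counting rather than the candidate pruning the Apriori class actually implements, and the kernel names are made up:

    from collections import Counter

    def frequent_kernel_sequences(kernels, min_support=3, max_len=3):
        # Count every contiguous kernel sequence up to max_len and keep the
        # ones that occur at least min_support times.
        counts = Counter()
        for length in range(1, max_len + 1):
            for start in range(len(kernels) - length + 1):
                counts[tuple(kernels[start:start + length])] += 1
        return [seq for seq, c in counts.items() if c >= min_support]

    # Three repeated Conv-BN-ReLU blocks: the triple shows up as frequent and
    # is therefore a candidate to be folded into a reusable Module.
    kernels = ["paddle.nn.Conv2D", "paddle.nn.BatchNorm", "paddle.nn.ReLU"] * 3
    for seq in frequent_kernel_sequences(kernels):
        print(seq)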