提交 b4eceac8 编写于 作者: W wjj19950828

Fixed FusedBatchNorm

上级 bfa2d61f
...@@ -563,7 +563,6 @@ class TFOpMapper(): ...@@ -563,7 +563,6 @@ class TFOpMapper():
n, h, w, c = input.out_shapes[0] n, h, w, c = input.out_shapes[0]
else: else:
n, c, h, w = input.out_shapes[0] n, c, h, w = input.out_shapes[0]
self.params["{}_{}".format(node.name, gamma.name)] = self.params[ self.params["{}_{}".format(node.name, gamma.name)] = self.params[
gamma.name] gamma.name]
self.params["{}_{}".format(node.name, beta.name)] = self.params[ self.params["{}_{}".format(node.name, beta.name)] = self.params[
...@@ -584,7 +583,8 @@ class TFOpMapper(): ...@@ -584,7 +583,8 @@ class TFOpMapper():
moving_mean.name)), moving_mean.name)),
moving_variance_name=string("{}_{}".format(node.name, moving_variance_name=string("{}_{}".format(node.name,
moving_var.name)), moving_var.name)),
is_test=True) is_test=True,
trainable_statistics=node.get_attr("is_training"))
if data_format == "NHWC": if data_format == "NHWC":
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
......
...@@ -21,27 +21,28 @@ from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import construct_ ...@@ -21,27 +21,28 @@ from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import construct_
from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers
from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
NoModuleStart = ["paddle.nn.ReLU"] NoModuleStart = ["paddle.nn.ReLU"]
class Apriori(object): class Apriori(object):
""" 使用Apriori算法挖掘频繁子图 """ 使用Apriori算法挖掘频繁子图
1. 构建频繁1项集 1. 构建频繁1项集
2. 挖掘频繁k项集 2. 挖掘频繁k项集
3. 最终k项集和节点数满足最少节点数的子图组成集合GS 3. 最终k项集和节点数满足最少节点数的子图组成集合GS
Args: Args:
min_support (int): 子图出现次数的最小值。 min_support (int): 子图出现次数的最小值。
""" """
def __init__(self, min_support): def __init__(self, min_support):
self.min_support = min_support self.min_support = min_support
def is_match(self, item, sublayers):
    """Check whether ``sublayers`` begins with the kernel sequence ``item``.

    Args:
        item (list): Kernel names, in order.
        sublayers (list): Layer objects exposing a ``kernel`` attribute.

    Returns:
        bool: True when ``sublayers`` has at least ``len(item)`` entries
        and the kernel of the i-th layer equals ``item[i]`` for every i.
        An empty ``item`` always matches.
    """
    # Too few layers can never cover the whole pattern.
    if len(sublayers) < len(item):
        return False
    return all(kernel == layer.kernel
               for kernel, layer in zip(item, sublayers))
def create_C1(self): def create_C1(self):
# 构建候选1-项集 # 构建候选1-项集
C1 = list() C1 = list()
...@@ -49,9 +50,9 @@ class Apriori(object): ...@@ -49,9 +50,9 @@ class Apriori(object):
if layer.kernel == "paddle.to_tensor" or \ if layer.kernel == "paddle.to_tensor" or \
layer.kernel == "prim.if" or \ layer.kernel == "prim.if" or \
layer.kernel == "prim.loop": #or \ layer.kernel == "prim.loop": #or \
# layer.kernel == "prim.list" or \ # layer.kernel == "prim.list" or \
# layer.kernel == "prim.tuple" or \ # layer.kernel == "prim.tuple" or \
# layer.kernel == "prim.dict_construct": # layer.kernel == "prim.dict_construct":
continue continue
if self.pd_graph.edges_in.get(layer_id, 0) == 0 and \ if self.pd_graph.edges_in.get(layer_id, 0) == 0 and \
self.pd_graph.edges_out.get(layer_id, 0) == 0: self.pd_graph.edges_out.get(layer_id, 0) == 0:
...@@ -59,7 +60,7 @@ class Apriori(object): ...@@ -59,7 +60,7 @@ class Apriori(object):
if [layer.kernel] not in C1: if [layer.kernel] not in C1:
C1.append([layer.kernel]) C1.append([layer.kernel])
return C1 return C1
def create_Ck(self, Lk_last, C1): def create_Ck(self, Lk_last, C1):
# 构建候选k-项集 # 构建候选k-项集
Ck = list() Ck = list()
...@@ -71,7 +72,7 @@ class Apriori(object): ...@@ -71,7 +72,7 @@ class Apriori(object):
continue continue
Ck.append(new_item) Ck.append(new_item)
return Ck return Ck
def generate_Lk_by_Ck(self, Ck): def generate_Lk_by_Ck(self, Ck):
# 生成频繁k-项集 # 生成频繁k-项集
Lk = list() Lk = list()
...@@ -82,9 +83,9 @@ class Apriori(object): ...@@ -82,9 +83,9 @@ class Apriori(object):
if self.is_match(item, sublayers): if self.is_match(item, sublayers):
count += 1 count += 1
if count >= self.min_support: if count >= self.min_support:
Lk.append(item) Lk.append(item)
return Lk return Lk
def run(self, graph): def run(self, graph):
self.pd_graph = graph self.pd_graph = graph
self.layers = graph.layers self.layers = graph.layers
...@@ -97,14 +98,15 @@ class Apriori(object): ...@@ -97,14 +98,15 @@ class Apriori(object):
Lk = self.generate_Lk_by_Ck(Ck) Lk = self.generate_Lk_by_Ck(Ck)
itemset.extend(Lk) itemset.extend(Lk)
return itemset return itemset
class DP(object): class DP(object):
""" 使用动态规划找到使代码最短的组合方式。 """ 使用动态规划找到使代码最短的组合方式。
""" """
def __init__(self, combination_itemset): def __init__(self, combination_itemset):
self.combination_itemset = combination_itemset self.combination_itemset = combination_itemset
def get_combination_id(self, combination, layers): def get_combination_id(self, combination, layers):
combination_id = list() combination_id = list()
for layer_obj in combination: for layer_obj in combination:
...@@ -117,7 +119,7 @@ class DP(object): ...@@ -117,7 +119,7 @@ class DP(object):
else: else:
combination_id.append(-1) combination_id.append(-1)
return combination_id return combination_id
def run(self, graph): def run(self, graph):
layers = graph.layers layers = graph.layers
layer_combination_list = list() layer_combination_list = list()
...@@ -135,7 +137,7 @@ class DP(object): ...@@ -135,7 +137,7 @@ class DP(object):
current_layer_id = list(layers.keys())[j] current_layer_id = list(layers.keys())[j]
current_layer = list(layers.values())[j] current_layer = list(layers.values())[j]
current_itemset.insert(0, current_layer_id) current_itemset.insert(0, current_layer_id)
kernel_itemset.insert(0, current_layer.kernel) kernel_itemset.insert(0, current_layer.kernel)
if kernel_itemset in self.combination_itemset: if kernel_itemset in self.combination_itemset:
current_count = len(layer_combination_list[j - 1]) current_count = len(layer_combination_list[j - 1])
all_count = current_count + 1 all_count = current_count + 1
...@@ -145,7 +147,8 @@ class DP(object): ...@@ -145,7 +147,8 @@ class DP(object):
if j - 1 < 0: if j - 1 < 0:
last_itemset = list() last_itemset = list()
else: else:
last_itemset = copy.deepcopy(layer_combination_list[j - 1]) last_itemset = copy.deepcopy(layer_combination_list[
j - 1])
else: else:
if j == prefix_ids[0]: if j == prefix_ids[0]:
min_count = len(layer_combination_list[j]) + 1 min_count = len(layer_combination_list[j]) + 1
...@@ -158,24 +161,25 @@ class DP(object): ...@@ -158,24 +161,25 @@ class DP(object):
final_combination = layer_combination_list[-1] final_combination = layer_combination_list[-1]
combination_id = self.get_combination_id(final_combination, layers) combination_id = self.get_combination_id(final_combination, layers)
return final_combination, combination_id return final_combination, combination_id
class ModuleGraph(object): class ModuleGraph(object):
""" 更新PaddleGraph,生成代码。 """ 更新PaddleGraph,生成代码。
""" """
def __init__(self, graph):
    """Bind the graph to rewrite and set up the code and parameter stores.

    Args:
        graph: Graph object to update; it must provide
            ``get_global_layers()``.
    """
    self.pd_graph = graph
    self.global_layers = graph.get_global_layers()
    # Generated code strings are appended here as modules are emitted.
    self.codes = list()
    # Records parameter-name nodes; consumed by update_parameters().
    self.param_tree = PamareterTree()
def get_updation_information(self):
    """Mine frequent kernel sequences and choose how to combine layers.

    Returns:
        tuple: ``(combination, combination_id)`` exactly as produced by
        ``DP.run`` on the current graph.
    """
    # A kernel sequence must occur at least 3 times to count as frequent.
    aprior = Apriori(3)
    combination_itemset = aprior.run(self.pd_graph)
    dp = DP(combination_itemset)
    combination, combination_id = dp.run(self.pd_graph)
    return combination, combination_id
def analyze_attrs_table(self, attrs_table): def analyze_attrs_table(self, attrs_table):
""" 分析属性表格,哪些属性取值不一致。 """ 分析属性表格,哪些属性取值不一致。
""" """
...@@ -196,12 +200,12 @@ class ModuleGraph(object): ...@@ -196,12 +200,12 @@ class ModuleGraph(object):
if len(elements_list) > 1: if len(elements_list) > 1:
max_ct = 0 max_ct = 0
for k, v in zip(elements_list, count_list): for k, v in zip(elements_list, count_list):
if v > max_ct and str(k) != "nan" : if v > max_ct and str(k) != "nan":
max_ele = k max_ele = k
max_ct = v max_ct = v
diff_attrs_column[column] = max_ele diff_attrs_column[column] = max_ele
return diff_attrs_column return diff_attrs_column
def analyze_graph(self, sub_layers_list): def analyze_graph(self, sub_layers_list):
def is_same(sub_layers1, sub_layers2, id1, id2): def is_same(sub_layers1, sub_layers2, id1, id2):
inputs1, outputs1 = ipt_opt_list[id1] inputs1, outputs1 = ipt_opt_list[id1]
...@@ -214,24 +218,34 @@ class ModuleGraph(object): ...@@ -214,24 +218,34 @@ class ModuleGraph(object):
layer_id2 = layer_id_list2[i] layer_id2 = layer_id_list2[i]
if layer_id2 not in self.pd_graph.edges_in: if layer_id2 not in self.pd_graph.edges_in:
return False return False
if len(self.pd_graph.edges_in[layer_id1]) != len(self.pd_graph.edges_in[layer_id2]): if len(self.pd_graph.edges_in[layer_id1]) != len(
self.pd_graph.edges_in[layer_id2]):
return False return False
for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[layer_id1]): for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[
layer_id1]):
ipt_layer_id2 = self.pd_graph.edges_in[layer_id2][j] ipt_layer_id2 = self.pd_graph.edges_in[layer_id2][j]
if (ipt_layer_id1 in layer_id_list1) ^ (ipt_layer_id2 in layer_id_list2): if (ipt_layer_id1 in layer_id_list1) ^ (
ipt_layer_id2 in layer_id_list2):
return False return False
if (layer_id1 in self.pd_graph.edges_out) ^ (layer_id2 in self.pd_graph.edges_out): if (layer_id1 in self.pd_graph.edges_out) ^ (
layer_id2 in self.pd_graph.edges_out):
return False return False
if (layer_id1 in self.pd_graph.edges_out) and (layer_id2 in self.pd_graph.edges_out): if (layer_id1 in self.pd_graph.edges_out) and (
layer_id2 in self.pd_graph.edges_out):
if (len(self.pd_graph.edges_out[layer_id1]) > 1 and len(self.pd_graph.edges_out[layer_id2]) == 1) or \ if (len(self.pd_graph.edges_out[layer_id1]) > 1 and len(self.pd_graph.edges_out[layer_id2]) == 1) or \
(len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) > 1): (len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) > 1):
return False return False
for j, opt_layer_id1 in enumerate(self.pd_graph.edges_out[layer_id1]): for j, opt_layer_id1 in enumerate(self.pd_graph.edges_out[
if len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) == 1: layer_id1]):
opt_layer_id2 = self.pd_graph.edges_out[layer_id2][j] if len(self.pd_graph.edges_out[layer_id1]) == 1 and len(
if (opt_layer_id1 in layer_id_list1) ^ (opt_layer_id2 in layer_id_list2): self.pd_graph.edges_out[layer_id2]) == 1:
opt_layer_id2 = self.pd_graph.edges_out[layer_id2][
j]
if (opt_layer_id1 in layer_id_list1) ^ (
opt_layer_id2 in layer_id_list2):
return False return False
return True return True
sub_layers_list_list = list() sub_layers_list_list = list()
id_list = list() id_list = list()
ipt_opt_list = list() ipt_opt_list = list()
...@@ -251,16 +265,16 @@ class ModuleGraph(object): ...@@ -251,16 +265,16 @@ class ModuleGraph(object):
sub_layers_list_list[j + 1].append(sub_layer) sub_layers_list_list[j + 1].append(sub_layer)
id_list.append(i) id_list.append(i)
return sub_layers_list_list return sub_layers_list_list
def merge_node(self, sub_layers_list, attrs_table, module_name): def merge_node(self, sub_layers_list, attrs_table, module_name):
sub_layers = sub_layers_list[0] sub_layers = sub_layers_list[0]
diff_attrs_column = self.analyze_attrs_table(attrs_table) diff_attrs_column = self.analyze_attrs_table(attrs_table)
sub_layers, _, _ = rename_layers(sub_layers) sub_layers, _, _ = rename_layers(sub_layers)
code_str = gen_layer_code(self.pd_graph, code_str = gen_layer_code(
sub_layers, self.pd_graph,
module_name, sub_layers,
different_attrs=diff_attrs_column) module_name,
different_attrs=diff_attrs_column)
self.codes.append(code_str) self.codes.append(code_str)
for index, sub_layers in enumerate(sub_layers_list): for index, sub_layers in enumerate(sub_layers_list):
inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers) inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers)
...@@ -270,18 +284,19 @@ class ModuleGraph(object): ...@@ -270,18 +284,19 @@ class ModuleGraph(object):
mn = module_name.lower() mn = module_name.lower()
outputs = ["{}_{}".format(mn, index)] + outputs outputs = ["{}_{}".format(mn, index)] + outputs
node_name = "{}_{}".format(module_name, index) node_name = "{}_{}".format(module_name, index)
diff_attrs = dict() diff_attrs = dict()
for column, element in diff_attrs_column.items(): for column, element in diff_attrs_column.items():
current_element = attrs_table.get(column).loc[node_name] current_element = attrs_table.get(column).loc[node_name]
if current_element != element: if current_element != element:
diff_attrs[column] = current_element diff_attrs[column] = current_element
new_layer = PaddleLayer(id=list(sub_layers.keys())[-1], new_layer = PaddleLayer(
kernel="module", id=list(sub_layers.keys())[-1],
inputs=inputs_dict, kernel="module",
outputs=outputs, inputs=inputs_dict,
module=module_name, outputs=outputs,
**diff_attrs) module=module_name,
**diff_attrs)
_, nn_param_nodes, _ = rename_layers(sub_layers, self.param_tree) _, nn_param_nodes, _ = rename_layers(sub_layers, self.param_tree)
param_node = PamareterNode(old_name=outputs[0]) param_node = PamareterNode(old_name=outputs[0])
for node in nn_param_nodes: for node in nn_param_nodes:
...@@ -297,7 +312,7 @@ class ModuleGraph(object): ...@@ -297,7 +312,7 @@ class ModuleGraph(object):
self.pd_graph.layers.pop(layer_id) self.pd_graph.layers.pop(layer_id)
self.pd_graph.build() self.pd_graph.build()
def convert_subgraph_to_layer(self, combination, combination_id): def convert_subgraph_to_layer(self, combination, combination_id):
combination_id_set = set(combination_id) combination_id_set = set(combination_id)
for s in list(combination_id_set): for s in list(combination_id_set):
...@@ -318,51 +333,57 @@ class ModuleGraph(object): ...@@ -318,51 +333,57 @@ class ModuleGraph(object):
else: else:
real_module_name = module_name + "__{}".format(i) real_module_name = module_name + "__{}".format(i)
if len(sub_layers_list) > 1: if len(sub_layers_list) > 1:
attrs_table = construct_attrs_table(sub_layers_list, module_name=real_module_name) attrs_table = construct_attrs_table(
self.merge_node(sub_layers_list, attrs_table, real_module_name) sub_layers_list, module_name=real_module_name)
layers, nn_param_nodes, _ = rename_layers(self.pd_graph.layers, self.param_tree, is_rename_module=True) self.merge_node(sub_layers_list, attrs_table,
code_str = gen_layer_code(self.pd_graph, real_module_name)
layers, layers, nn_param_nodes, _ = rename_layers(
self.pd_graph.name) self.pd_graph.layers, self.param_tree, is_rename_module=True)
code_str = gen_layer_code(self.pd_graph, layers, self.pd_graph.name)
self.codes.append(code_str) self.codes.append(code_str)
param_node = PamareterNode(old_name="Module") param_node = PamareterNode(old_name="Module")
for node in nn_param_nodes: for node in nn_param_nodes:
param_node.add_child(node) param_node.add_child(node)
self.param_tree.add_node(param_node) self.param_tree.add_node(param_node)
def update_parameters(self):
    """Rename graph parameters using the parameter tree's old->new map.

    For every ``(old_name, new_name)`` pair in ``self.param_tree.old2new``:

    * a parameter key starting with ``"<old_name>."`` has only that
      leading prefix replaced by ``"<new_name>."``;
    * a parameter key equal to ``old_name`` is renamed to ``new_name``.

    ``self.pd_graph.parameters`` is modified in place.
    """
    # NOTE(review): traverse() presumably fills ``old2new`` -- confirm.
    self.param_tree.traverse()
    # Snapshot the keys because the dict is mutated while renaming.
    full_old_name_list = list(self.pd_graph.parameters.keys())
    for old_name, new_name in self.param_tree.old2new.items():
        old_prefix = "{}.".format(old_name)
        new_prefix = "{}.".format(new_name)
        for full_old_name in full_old_name_list:
            if full_old_name.startswith(old_prefix):
                # Swap the leading prefix only; str.replace would also
                # rewrite later occurrences of "<old_name>." in the key.
                full_new_name = new_prefix + full_old_name[len(old_prefix):]
            elif full_old_name == old_name:
                full_new_name = new_name
            else:
                continue
            params = self.pd_graph.parameters.pop(full_old_name)
            self.pd_graph.parameters[full_new_name] = params
def save_source_files(self, save_dir): def save_source_files(self, save_dir):
def gen_main_code(): def gen_main_code():
input_data_name = ', '.join(self.pd_graph.inputs) input_data_name = ', '.join(self.pd_graph.inputs)
run_func_list = list() run_func_list = list()
run_func_list.append("def main({}):".format(input_data_name)) run_func_list.append("def main({}):".format(input_data_name))
run_func_list.append(" # There are {} inputs.".format(len(self.pd_graph.inputs_info))) run_func_list.append(" # There are {} inputs.".format(
len(self.pd_graph.inputs_info)))
for k, v in self.pd_graph.inputs_info.items(): for k, v in self.pd_graph.inputs_info.items():
run_func_list.append(" # {}: shape-{}, type-{}.".format(k, v[0], v[1])) run_func_list.append(" # {}: shape-{}, type-{}.".format(k, v[
run_func_list.extend( 0], v[1]))
[" paddle.disable_static()", run_func_list.extend([
" params = paddle.load('{}')".format(osp.join(osp.abspath(save_dir), "model.pdparams")), " paddle.disable_static()",
" model = {}()".format(self.pd_graph.name), " params = paddle.load('{}')".format(
" model.set_dict(params)", osp.join(osp.abspath(save_dir), "model.pdparams")),
" model.eval()", " model = {}()".format(self.pd_graph.name),
" out = model({})".format(input_data_name), " model.set_dict(params)", " model.eval()",
" return out"]) " out = model({})".format(input_data_name), " return out"
])
return "\n".join(run_func_list) return "\n".join(run_func_list)
combination, combination_id = self.get_updation_information() combination, combination_id = self.get_updation_information()
self.convert_subgraph_to_layer(combination, combination_id) self.convert_subgraph_to_layer(combination, combination_id)
self.update_parameters() self.update_parameters()
...@@ -382,4 +403,3 @@ class ModuleGraph(object): ...@@ -382,4 +403,3 @@ class ModuleGraph(object):
run_func = gen_main_code() run_func = gen_main_code()
f.write(run_func) f.write(run_func)
f.close() f.close()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册