提交 b4eceac8，作者：wjj19950828

Fixed FusedBatchNorm

上级 bfa2d61f
......@@ -563,7 +563,6 @@ class TFOpMapper():
n, h, w, c = input.out_shapes[0]
else:
n, c, h, w = input.out_shapes[0]
self.params["{}_{}".format(node.name, gamma.name)] = self.params[
gamma.name]
self.params["{}_{}".format(node.name, beta.name)] = self.params[
......@@ -584,7 +583,8 @@ class TFOpMapper():
moving_mean.name)),
moving_variance_name=string("{}_{}".format(node.name,
moving_var.name)),
is_test=True)
is_test=True,
trainable_statistics=node.get_attr("is_training"))
if data_format == "NHWC":
self.paddle_graph.add_layer(
......
......@@ -21,9 +21,9 @@ from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import construct_
from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers
from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
NoModuleStart = ["paddle.nn.ReLU"]
class Apriori(object):
""" 使用Apriori算法挖掘频繁子图
1. 构建频繁1项集
......@@ -33,6 +33,7 @@ class Apriori(object):
Args:
min_support (int): 子图出现次数的最小值。
"""
def __init__(self, min_support):
self.min_support = min_support
......@@ -49,9 +50,9 @@ class Apriori(object):
if layer.kernel == "paddle.to_tensor" or \
layer.kernel == "prim.if" or \
layer.kernel == "prim.loop": #or \
# layer.kernel == "prim.list" or \
# layer.kernel == "prim.tuple" or \
# layer.kernel == "prim.dict_construct":
# layer.kernel == "prim.list" or \
# layer.kernel == "prim.tuple" or \
# layer.kernel == "prim.dict_construct":
continue
if self.pd_graph.edges_in.get(layer_id, 0) == 0 and \
self.pd_graph.edges_out.get(layer_id, 0) == 0:
......@@ -100,8 +101,9 @@ class Apriori(object):
class DP(object):
""" 使用动态规划找到使代码最短的组合方式。
""" 使用动态规划找到使代码最短的组合方式。
"""
def __init__(self, combination_itemset):
self.combination_itemset = combination_itemset
......@@ -145,7 +147,8 @@ class DP(object):
if j - 1 < 0:
last_itemset = list()
else:
last_itemset = copy.deepcopy(layer_combination_list[j - 1])
last_itemset = copy.deepcopy(layer_combination_list[
j - 1])
else:
if j == prefix_ids[0]:
min_count = len(layer_combination_list[j]) + 1
......@@ -163,6 +166,7 @@ class DP(object):
class ModuleGraph(object):
""" 更新PaddleGraph,生成代码。
"""
def __init__(self, graph):
self.pd_graph = graph
self.global_layers = graph.get_global_layers()
......@@ -196,7 +200,7 @@ class ModuleGraph(object):
if len(elements_list) > 1:
max_ct = 0
for k, v in zip(elements_list, count_list):
if v > max_ct and str(k) != "nan" :
if v > max_ct and str(k) != "nan":
max_ele = k
max_ct = v
diff_attrs_column[column] = max_ele
......@@ -214,24 +218,34 @@ class ModuleGraph(object):
layer_id2 = layer_id_list2[i]
if layer_id2 not in self.pd_graph.edges_in:
return False
if len(self.pd_graph.edges_in[layer_id1]) != len(self.pd_graph.edges_in[layer_id2]):
if len(self.pd_graph.edges_in[layer_id1]) != len(
self.pd_graph.edges_in[layer_id2]):
return False
for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[layer_id1]):
for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[
layer_id1]):
ipt_layer_id2 = self.pd_graph.edges_in[layer_id2][j]
if (ipt_layer_id1 in layer_id_list1) ^ (ipt_layer_id2 in layer_id_list2):
if (ipt_layer_id1 in layer_id_list1) ^ (
ipt_layer_id2 in layer_id_list2):
return False
if (layer_id1 in self.pd_graph.edges_out) ^ (layer_id2 in self.pd_graph.edges_out):
if (layer_id1 in self.pd_graph.edges_out) ^ (
layer_id2 in self.pd_graph.edges_out):
return False
if (layer_id1 in self.pd_graph.edges_out) and (layer_id2 in self.pd_graph.edges_out):
if (layer_id1 in self.pd_graph.edges_out) and (
layer_id2 in self.pd_graph.edges_out):
if (len(self.pd_graph.edges_out[layer_id1]) > 1 and len(self.pd_graph.edges_out[layer_id2]) == 1) or \
(len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) > 1):
return False
for j, opt_layer_id1 in enumerate(self.pd_graph.edges_out[layer_id1]):
if len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) == 1:
opt_layer_id2 = self.pd_graph.edges_out[layer_id2][j]
if (opt_layer_id1 in layer_id_list1) ^ (opt_layer_id2 in layer_id_list2):
for j, opt_layer_id1 in enumerate(self.pd_graph.edges_out[
layer_id1]):
if len(self.pd_graph.edges_out[layer_id1]) == 1 and len(
self.pd_graph.edges_out[layer_id2]) == 1:
opt_layer_id2 = self.pd_graph.edges_out[layer_id2][
j]
if (opt_layer_id1 in layer_id_list1) ^ (
opt_layer_id2 in layer_id_list2):
return False
return True
sub_layers_list_list = list()
id_list = list()
ipt_opt_list = list()
......@@ -252,12 +266,12 @@ class ModuleGraph(object):
id_list.append(i)
return sub_layers_list_list
def merge_node(self, sub_layers_list, attrs_table, module_name):
sub_layers = sub_layers_list[0]
diff_attrs_column = self.analyze_attrs_table(attrs_table)
sub_layers, _, _ = rename_layers(sub_layers)
code_str = gen_layer_code(self.pd_graph,
code_str = gen_layer_code(
self.pd_graph,
sub_layers,
module_name,
different_attrs=diff_attrs_column)
......@@ -275,7 +289,8 @@ class ModuleGraph(object):
current_element = attrs_table.get(column).loc[node_name]
if current_element != element:
diff_attrs[column] = current_element
new_layer = PaddleLayer(id=list(sub_layers.keys())[-1],
new_layer = PaddleLayer(
id=list(sub_layers.keys())[-1],
kernel="module",
inputs=inputs_dict,
outputs=outputs,
......@@ -318,12 +333,13 @@ class ModuleGraph(object):
else:
real_module_name = module_name + "__{}".format(i)
if len(sub_layers_list) > 1:
attrs_table = construct_attrs_table(sub_layers_list, module_name=real_module_name)
self.merge_node(sub_layers_list, attrs_table, real_module_name)
layers, nn_param_nodes, _ = rename_layers(self.pd_graph.layers, self.param_tree, is_rename_module=True)
code_str = gen_layer_code(self.pd_graph,
layers,
self.pd_graph.name)
attrs_table = construct_attrs_table(
sub_layers_list, module_name=real_module_name)
self.merge_node(sub_layers_list, attrs_table,
real_module_name)
layers, nn_param_nodes, _ = rename_layers(
self.pd_graph.layers, self.param_tree, is_rename_module=True)
code_str = gen_layer_code(self.pd_graph, layers, self.pd_graph.name)
self.codes.append(code_str)
param_node = PamareterNode(old_name="Module")
for node in nn_param_nodes:
......@@ -334,11 +350,13 @@ class ModuleGraph(object):
""" 更新参数。
"""
self.param_tree.traverse()
full_old_name_list = copy.deepcopy(list(self.pd_graph.parameters.keys()))
full_old_name_list = copy.deepcopy(
list(self.pd_graph.parameters.keys()))
for old_name, new_name in self.param_tree.old2new.items():
for full_old_name in full_old_name_list:
if full_old_name.startswith("{}.".format(old_name)):
full_new_name = full_old_name.replace("{}.".format(old_name), "{}.".format(new_name))
full_new_name = full_old_name.replace(
"{}.".format(old_name), "{}.".format(new_name))
params = self.pd_graph.parameters.pop(full_old_name)
self.pd_graph.parameters[full_new_name] = params
if full_old_name == old_name:
......@@ -351,18 +369,21 @@ class ModuleGraph(object):
input_data_name = ', '.join(self.pd_graph.inputs)
run_func_list = list()
run_func_list.append("def main({}):".format(input_data_name))
run_func_list.append(" # There are {} inputs.".format(len(self.pd_graph.inputs_info)))
run_func_list.append(" # There are {} inputs.".format(
len(self.pd_graph.inputs_info)))
for k, v in self.pd_graph.inputs_info.items():
run_func_list.append(" # {}: shape-{}, type-{}.".format(k, v[0], v[1]))
run_func_list.extend(
[" paddle.disable_static()",
" params = paddle.load('{}')".format(osp.join(osp.abspath(save_dir), "model.pdparams")),
run_func_list.append(" # {}: shape-{}, type-{}.".format(k, v[
0], v[1]))
run_func_list.extend([
" paddle.disable_static()",
" params = paddle.load('{}')".format(
osp.join(osp.abspath(save_dir), "model.pdparams")),
" model = {}()".format(self.pd_graph.name),
" model.set_dict(params)",
" model.eval()",
" out = model({})".format(input_data_name),
" return out"])
" model.set_dict(params)", " model.eval()",
" out = model({})".format(input_data_name), " return out"
])
return "\n".join(run_func_list)
combination, combination_id = self.get_updation_information()
self.convert_subgraph_to_layer(combination, combination_id)
self.update_parameters()
......@@ -382,4 +403,3 @@ class ModuleGraph(object):
run_func = gen_main_code()
f.write(run_func)
f.close()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册