Commit 9ecd568e, authored by J jiangjiajun

add merge_batch_norm optimization

Parent e6e5dbb9
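
The new pass looks for the inference-time decomposition of batch normalization that TensorFlow emits when the op is not already fused, and rewrites the whole subgraph into a single fluid batch_norm layer. A minimal NumPy sketch of the arithmetic being matched (names here are illustrative, not taken from the diff):

    import numpy as np

    def decomposed_batch_norm(x, gamma, beta, moving_mean, moving_variance, epsilon):
        # scale = Rsqrt(variance + epsilon) * gamma   -> the Mul fed by Rsqrt below
        scale = gamma / np.sqrt(moving_variance + epsilon)
        # y = Mul(x, scale) + Sub(beta, moving_mean * scale)   -> the final Add
        return x * scale + (beta - moving_mean * scale)
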
@@ -106,6 +106,7 @@ def tf2paddle(model_path,
         # optimizer below is experimental
         optimizer.merge_activation()
         optimizer.merge_bias()
+        optimizer.merge_batch_norm()
     else:
         mapper = TFOpMapperNHWC(model)
         optimizer = TFOptimizer(mapper)

@@ -351,3 +351,130 @@ class TFOptimizer(object):
                 if node.fluid_code.layers[-1].op == "transpose":
                     node.fluid_code.layers[-2].output = name
                     del node.fluid_code.layers[-1]
+
+    def merge_batch_norm(self):
+        for i, name in enumerate(self.graph.topo_sort):
+            node = self.graph.get_node(name)
+            if node is None:
+                continue
+            is_batch_norm = True
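+            # Candidate root: the final Add of TensorFlow's decomposed
+            # batch norm, Add(Mul(x, scale), Sub(beta, moving_mean * scale)).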
+            if node.layer_type == "Add":
+                in_nodes0 = [
+                    self.graph.get_node(in_name) for in_name in node.inputs
+                ]
+                if (in_nodes0[0].layer_type != "Mul"
+                        or in_nodes0[1].layer_type != "Sub"):
+                    is_batch_norm = False
+                    continue
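+                # Left branch: Mul(conv output, scale);
+                # right branch: Sub(beta, Mul(moving_mean, scale)).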
+                in_nodes1 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes0[0].inputs
+                ]
+                in_nodes2 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes0[1].inputs
+                ]
+                if len(in_nodes1[0].out_shapes[0]) != 4:
+                    is_batch_norm = False
+                    continue
+                if in_nodes1[1].layer_type != "Mul":
+                    is_batch_norm = False
+                    continue
+                if (in_nodes2[0].layer_type != "Const"
+                        or in_nodes2[1].layer_type != "Mul"):
+                    is_batch_norm = False
+                    continue
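+                # scale must itself be Mul(Rsqrt(variance + epsilon), gamma).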
+                in_nodes3 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes1[1].inputs
+                ]
+                if (in_nodes3[0].layer_type != "Rsqrt"
+                        or in_nodes3[1].layer_type != "Const"):
+                    is_batch_norm = False
+                    continue
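+                # The Sub branch must reuse the very same scale node:
+                # Mul(moving_mean, scale).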
+                in_nodes4 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes2[1].inputs
+                ]
+                if (in_nodes4[0].layer_type != "Const"
+                        or in_nodes4[1].layer_name != in_nodes1[1].layer_name):
+                    is_batch_norm = False
+                    continue
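+                # Rsqrt must take Add(moving_variance, epsilon), both constants.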
+                in_nodes5 = self.graph.get_node(in_nodes3[0].inputs[0])
+                if in_nodes5.layer_type != "Add":
+                    is_batch_norm = False
+                    continue
+                in_nodes6 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes5.inputs
+                ]
+                if (in_nodes6[0].layer_type != "Const"
+                        or in_nodes6[1].layer_type != "Const"):
+                    is_batch_norm = False
+                    continue
+                conv_shape = in_nodes1[0].out_shapes[0]
+                if conv_shape[3] < 0:
+                    is_batch_norm = False
+                    continue
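+                # Every parameter tensor must match the channel count
+                # (NHWC layout, axis 3); epsilon must be a scalar.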
+                # moving_variance
+                if in_nodes6[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+                # epsilon
+                if in_nodes6[1].value.size != 1:
+                    is_batch_norm = False
+                    continue
+                # gamma
+                if in_nodes3[1].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+                # moving_mean
+                if in_nodes4[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+                # beta
+                if in_nodes2[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
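+                # Pattern confirmed: rewrite this Add into a single fused
+                # batch_norm layer fed directly by the conv output.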
+                if is_batch_norm:
+                    index = in_nodes1[0].outputs.index(in_nodes0[0].layer_name)
+                    del in_nodes1[0].outputs[index]
+                    node.layer_type = "FusedBatchNorm"
+                    node.inputs = [in_nodes1[0].layer_name]
+                    act = node.fluid_code.layers[-1].param_attr.get("act", None)
+                    node.fluid_code.clear()
+                    attr = {
+                        "epsilon": in_nodes6[1].value,
+                        "param_attr": string(in_nodes3[1].layer_name),
+                        "bias_attr": string(in_nodes2[0].layer_name),
+                        "moving_mean_name": string(in_nodes4[0].layer_name),
+                        "moving_variance_name": string(in_nodes6[0].layer_name),
+                        "is_test": True,
+                        "act": act
+                    }
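+                    # string() and cp are helpers defined elsewhere in this
+                    # module (string() quotes a name for the emitted code;
+                    # cp is Python's copy module); assumed from context.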
+                    node.fluid_code.add_layer("batch_norm",
+                                              inputs=cp.copy(in_nodes1[0]),
+                                              output=node,
+                                              param_attr=attr)
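+                    # Remove the matched intermediate nodes from the graph.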
+                    del self.graph.node_map[in_nodes0[0].layer_name]
+                    del self.graph.node_map[in_nodes0[1].layer_name]
+                    del self.graph.node_map[in_nodes1[1].layer_name]
+                    del self.graph.node_map[in_nodes2[1].layer_name]
+                    del self.graph.node_map[in_nodes3[0].layer_name]
+                    del self.graph.node_map[in_nodes4[0].layer_name]
+                    del self.graph.node_map[in_nodes5.layer_name]