From 9ecd568e8365cdc9ba8df2f43475519bcbfdaa8c Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Wed, 4 Sep 2019 12:05:07 +0800
Subject: [PATCH] add merge batch_norm optimization

---
 x2paddle/convert.py                |   1 +
 x2paddle/optimizer/tf_optimizer.py | 127 +++++++++++++++++++++++++++++
 2 files changed, 128 insertions(+)

diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index 9642fcc..673c9f2 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -106,6 +106,7 @@ def tf2paddle(model_path,
         # optimizer below is experimental
         optimizer.merge_activation()
         optimizer.merge_bias()
+        optimizer.merge_batch_norm()
     else:
         mapper = TFOpMapperNHWC(model)
         optimizer = TFOptimizer(mapper)
diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py
index 223a9b8..610a433 100644
--- a/x2paddle/optimizer/tf_optimizer.py
+++ b/x2paddle/optimizer/tf_optimizer.py
@@ -351,3 +351,130 @@ class TFOptimizer(object):
                 if node.fluid_code.layers[-1].op == "transpose":
                     node.fluid_code.layers[-2].output = name
                     del node.fluid_code.layers[-1]
+
+    def merge_batch_norm(self):
+        for i, name in enumerate(self.graph.topo_sort):
+            node = self.graph.get_node(name)
+            if node is None:
+                continue
+            is_batch_norm = True
+            if node.layer_type == "Add":
+                in_nodes0 = [
+                    self.graph.get_node(in_name) for in_name in node.inputs
+                ]
+                if in_nodes0[0].layer_type != "Mul" or in_nodes0[
+                        1].layer_type != "Sub":
+                    is_batch_norm = False
+                    continue
+
+                in_nodes1 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes0[0].inputs
+                ]
+                in_nodes2 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes0[1].inputs
+                ]
+                if len(in_nodes1[0].out_shapes[0]) != 4:
+                    is_batch_norm = False
+                    continue
+                if in_nodes1[1].layer_type != "Mul":
+                    is_batch_norm = False
+                    continue
+
+                if in_nodes2[0].layer_type != "Const" or in_nodes2[
+                        1].layer_type != "Mul":
+                    is_batch_norm = False
+                    continue
+
+                in_nodes3 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes1[1].inputs
+                ]
+                if in_nodes3[0].layer_type != "Rsqrt" or in_nodes3[
+                        1].layer_type != "Const":
+                    is_batch_norm = False
+                    continue
+
+                in_nodes4 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes2[1].inputs
+                ]
+                if in_nodes4[0].layer_type != "Const" or in_nodes4[
+                        1].layer_name != in_nodes1[1].layer_name:
+                    is_batch_norm = False
+                    continue
+
+                in_nodes5 = self.graph.get_node(in_nodes3[0].inputs[0])
+                if in_nodes5.layer_type != "Add":
+                    is_batch_norm = False
+                    continue
+
+                in_nodes6 = [
+                    self.graph.get_node(in_name) for in_name in in_nodes5.inputs
+                ]
+                if in_nodes6[0].layer_type != "Const" or in_nodes6[
+                        1].layer_type != "Const":
+                    is_batch_norm = False
+                    continue
+
+                conv_shape = in_nodes1[0].out_shapes[0]
+                if conv_shape[3] < 0:
+                    is_batch_norm = False
+                    continue
+
+                # moving_variance
+                if in_nodes6[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+
+                # epsilon
+                if in_nodes6[1].value.size != 1:
+                    is_batch_norm = False
+                    continue
+
+                # gamma
+                if in_nodes3[1].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+
+                # moving_mean
+                if in_nodes4[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+
+                # beta
+                if in_nodes2[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+
+                if is_batch_norm:
+                    index = in_nodes1[0].outputs.index(in_nodes0[0].layer_name)
+                    del in_nodes1[0].outputs[index]
+                    node.layer_type = "FusedBatchNorm"
+                    node.inputs = [in_nodes1[0].layer_name]
+                    node.outputs = node.outputs
+                    act = node.fluid_code.layers[-1].param_attr.get("act", None)
+                    node.fluid_code.clear()
+                    attr = {
+                        "epsilon": in_nodes6[1].value,
+                        "param_attr": string(in_nodes3[1].layer_name),
+                        "bias_attr": string(in_nodes2[0].layer_name),
+                        "moving_mean_name": string(in_nodes4[0].layer_name),
+                        "moving_variance_name": string(in_nodes6[0].layer_name),
+                        "is_test": True,
+                        "act": act
+                    }
+
+                    node.fluid_code.add_layer("batch_norm",
+                                              inputs=cp.copy(in_nodes1[0]),
+                                              output=node,
+                                              param_attr=attr)
+
+                del self.graph.node_map[in_nodes0[0].layer_name]
+                del self.graph.node_map[in_nodes0[1].layer_name]
+                del self.graph.node_map[in_nodes1[1].layer_name]
+                del self.graph.node_map[in_nodes2[1].layer_name]
+                del self.graph.node_map[in_nodes3[0].layer_name]
+                del self.graph.node_map[in_nodes4[0].layer_name]
+                del self.graph.node_map[in_nodes5.layer_name]
--
GitLab