diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index 0ead2235275015564751224151792b23d8984fb2..904a10b916fa2ac17d60f841c31527e8fc0054ea 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -106,6 +106,8 @@ def tf2paddle(model_path,
         # optimizer below is experimental
         optimizer.merge_activation()
         optimizer.merge_bias()
+        optimizer.merge_batch_norm()
+        optimizer.merge_prelu()
     else:
         mapper = TFOpMapperNHWC(model)
         optimizer = TFOptimizer(mapper)
diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py
index 223a9b816625c6446ac90bb6184bcf62b2c385f6..1980510147bb46079b6dea0f7f34af48477c2c56 100644
--- a/x2paddle/optimizer/tf_optimizer.py
+++ b/x2paddle/optimizer/tf_optimizer.py
@@ -16,6 +16,7 @@
 from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
 from x2paddle.core.fluid_code import Layer
 from x2paddle.core.util import *
+import numpy
 import copy as cp
 
 
@@ -351,3 +352,311 @@ class TFOptimizer(object):
             if node.fluid_code.layers[-1].op == "transpose":
                 node.fluid_code.layers[-2].output = name
                 del node.fluid_code.layers[-1]
+
+    def merge_batch_norm(self):
+        for i, name in enumerate(self.graph.topo_sort):
+            node = self.graph.get_node(name)
+            if node is None:
+                continue
+            is_batch_norm = True
+            if node.layer_type == "Add":
+                in_nodes0 = [
+                    self.graph.get_node(in_name) for in_name in node.inputs
+                ]
+                if in_nodes0[0].layer_type != "Mul" or in_nodes0[
+                        1].layer_type != "Sub":
+                    is_batch_norm = False
+                    continue
+
+                in_nodes1 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes0[0].inputs
+                ]
+                in_nodes2 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes0[1].inputs
+                ]
+                if len(in_nodes1[0].out_shapes[0]) != 4:
+                    is_batch_norm = False
+                    continue
+                if in_nodes1[1].layer_type != "Mul":
+                    is_batch_norm = False
+                    continue
+
+                if in_nodes2[0].layer_type != "Const" or in_nodes2[
+                        1].layer_type != "Mul":
+                    is_batch_norm = False
+                    continue
+
+                in_nodes3 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes1[1].inputs
+                ]
+                if in_nodes3[0].layer_type != "Rsqrt" or in_nodes3[
+                        1].layer_type != "Const":
+                    is_batch_norm = False
+                    continue
+
+                in_nodes4 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes2[1].inputs
+                ]
+                if in_nodes4[0].layer_type != "Const" or in_nodes4[
+                        1].layer_name != in_nodes1[1].layer_name:
+                    is_batch_norm = False
+                    continue
+
+                in_nodes5 = self.graph.get_node(in_nodes3[0].inputs[0])
+                if in_nodes5.layer_type != "Add":
+                    is_batch_norm = False
+                    continue
+
+                in_nodes6 = [
+                    self.graph.get_node(in_name) for in_name in in_nodes5.inputs
+                ]
+                if in_nodes6[0].layer_type != "Const" or in_nodes6[
+                        1].layer_type != "Const":
+                    is_batch_norm = False
+                    continue
+
+                if len(in_nodes0[0].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes0[1].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes1[1].outputs) != 2:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes2[0].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes2[1].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes3[0].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes3[1].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes4[0].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes5.outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes6[0].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+                if len(in_nodes6[1].outputs) != 1:
+                    is_batch_norm = False
+                    continue
+
+                conv_shape = in_nodes1[0].out_shapes[0]
+                if conv_shape[3] < 0:
+                    is_batch_norm = False
+                    continue
+
+                # moving_variance
+                if in_nodes6[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+
+                # epsilon
+                if in_nodes6[1].value.size != 1:
+                    is_batch_norm = False
+                    continue
+
+                # gamma
+                if in_nodes3[1].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+
+                # moving_mean
+                if in_nodes4[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+
+                # beta
+                if in_nodes2[0].value.size != conv_shape[3]:
+                    is_batch_norm = False
+                    continue
+
+                if is_batch_norm:
+                    index = in_nodes1[0].outputs.index(in_nodes0[0].layer_name)
+                    del in_nodes1[0].outputs[index]
+                    node.layer_type = "FusedBatchNorm"
+                    node.inputs = [in_nodes1[0].layer_name]
+                    node.outputs = node.outputs
+                    act = node.fluid_code.layers[-1].param_attr.get("act", None)
+                    node.fluid_code.clear()
+                    attr = {
+                        "epsilon": in_nodes6[1].value,
+                        "param_attr": string(in_nodes3[1].layer_name),
+                        "bias_attr": string(in_nodes2[0].layer_name),
+                        "moving_mean_name": string(in_nodes4[0].layer_name),
+                        "moving_variance_name": string(in_nodes6[0].layer_name),
+                        "is_test": True,
+                        "act": act
+                    }
+
+                    node.fluid_code.add_layer(
+                        "batch_norm",
+                        inputs=in_nodes1[0].fluid_code.layers[-1].output,
+                        output=node,
+                        param_attr=attr)
+
+                    del self.graph.node_map[in_nodes0[0].layer_name]
+                    del self.graph.node_map[in_nodes0[1].layer_name]
+                    del self.graph.node_map[in_nodes1[1].layer_name]
+                    del self.graph.node_map[in_nodes2[1].layer_name]
+                    del self.graph.node_map[in_nodes3[0].layer_name]
+                    del self.graph.node_map[in_nodes4[0].layer_name]
+                    del self.graph.node_map[in_nodes5.layer_name]
+
+    def merge_prelu(self):
+        for i, name in enumerate(self.graph.topo_sort):
+            node = self.graph.get_node(name)
+            if node is None:
+                continue
+            is_prelu = True
+            if node.layer_type == "Add":
+                in_nodes0 = [
+                    self.graph.get_node(in_name) for in_name in node.inputs
+                ]
+                if in_nodes0[0].layer_type != "Relu" or in_nodes0[
+                        1].layer_type != "Mul":
+                    is_prelu = False
+                    continue
+                if len(in_nodes0[0].outputs) != 1 or len(
+                        in_nodes0[1].outputs) != 1:
+                    is_prelu = False
+                    continue
+
+                in_nodes1 = self.graph.get_node(in_nodes0[0].inputs[0])
+                in_nodes2 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes0[1].inputs
+                ]
+                if in_nodes2[1].layer_type != "Const" or numpy.fabs(
+                        in_nodes2[1].value - 0.5) > 1e-06:
+                    is_prelu = False
+                    continue
+                if in_nodes2[0].layer_type != "Mul":
+                    is_prelu = False
+                    continue
+                if len(in_nodes2[1].outputs) != 1 or len(
+                        in_nodes2[0].outputs) != 1:
+                    is_prelu = False
+                    continue
+
+                in_nodes3 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes2[0].inputs
+                ]
+                if in_nodes3[0].layer_type != "Const" or in_nodes3[
+                        1].layer_type != "Sub":
+                    is_prelu = False
+                    continue
+                if len(in_nodes3[0].outputs) != 1 or len(
+                        in_nodes3[1].outputs) != 1:
+                    is_prelu = False
+                    continue
+
+                in_nodes4 = [
+                    self.graph.get_node(in_name)
+                    for in_name in in_nodes3[1].inputs
+                ]
+                if in_nodes4[0].layer_name != in_nodes1.layer_name or in_nodes4[
+                        1].layer_type != "Abs":
+                    is_prelu = False
+                    continue
+                if len(in_nodes4[1].outputs) != 1:
+                    is_prelu = False
+                    continue
+
+                in_nodes5 = self.graph.get_node(in_nodes4[1].inputs[0])
+                if in_nodes5.layer_name != in_nodes1.layer_name:
+                    is_prelu = False
+                    continue
+
+                if len(in_nodes0[0].outputs) != 1:
+                    is_prelu = False
+                    continue
+                if len(in_nodes0[1].outputs) != 1:
+                    is_prelu = False
+                    continue
+                if len(in_nodes1.outputs) < 3:
+                    is_prelu = False
+                    continue
+                if len(in_nodes2[0].outputs) != 1:
+                    is_prelu = False
+                    continue
+                if len(in_nodes2[1].outputs) != 1:
+                    is_prelu = False
+                    continue
+                if len(in_nodes3[0].outputs) != 1:
+                    is_prelu = False
+                    continue
+                if len(in_nodes3[1].outputs) != 1:
+                    is_prelu = False
+                    continue
+                if len(in_nodes4[1].outputs) != 1:
+                    is_prelu = False
+                    continue
+
+                mode = None
+                in_shape = in_nodes1.out_shapes[0]
+                if in_shape == list(in_nodes3[0].value.shape):
+                    mode = "element"
+                elif len(in_nodes3[0].value.shape) == 0:
+                    mode = "all"
+                elif len(in_nodes3[0].value.shape
+                         ) == 1 and in_nodes3[0].value.shape[0] == 1:
+                    mode = "all"
+                elif len(in_shape) == 4 and len(
+                        in_nodes3[0].value.shape
+                ) == 1 and in_nodes3[0].value.shape[0] == in_shape[-1]:
+                    mode = "channel"
+                    weight = self.op_mapper.weights[in_nodes3[0].layer_name]
+                    weight = numpy.expand_dims(weight, 0)
+                    weight = numpy.expand_dims(weight, 2)
+                    weight = numpy.expand_dims(weight, 3)
+                    self.op_mapper.weights[in_nodes3[0].layer_name] = weight
+                    in_nodes3[0].fluid_code.layers[0].param_attr["shape"] = [
+                        1, in_shape[-1], 1, 1
+                    ]
+                else:
+                    is_prelu = False
+                    continue
+
+                if is_prelu:
+                    index = in_nodes1.outputs.index(in_nodes0[0].layer_name)
+                    del in_nodes1.outputs[index]
+                    index = in_nodes1.outputs.index(in_nodes3[1].layer_name)
+                    del in_nodes1.outputs[index]
+                    index = in_nodes1.outputs.index(in_nodes4[1].layer_name)
+                    del in_nodes1.outputs[index]
+
+                    node.layer_type = "Prelu"
+                    node.inputs = [in_nodes1.layer_name]
+                    node.outputs = node.outputs
+                    act = node.fluid_code.layers[-1].param_attr.get("act", None)
+                    node.fluid_code.clear()
+                    attr = {
+                        "mode": string(mode),
+                        "param_attr": string(in_nodes3[0].layer_name)
+                    }
+
+                    node.fluid_code.add_layer(
+                        "prelu",
+                        inputs=in_nodes1.fluid_code.layers[-1].output,
+                        output=node,
+                        param_attr=attr)
+                    del self.graph.node_map[in_nodes0[0].layer_name]
+                    del self.graph.node_map[in_nodes0[1].layer_name]
+                    del self.graph.node_map[in_nodes2[0].layer_name]
+                    del self.graph.node_map[in_nodes2[1].layer_name]
+                    del self.graph.node_map[in_nodes3[1].layer_name]
+                    del self.graph.node_map[in_nodes4[1].layer_name]
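
Note (not part of the patch): merge_batch_norm() fuses the unfused subgraph TensorFlow emits for inference-time batch normalization, i.e. Add(Mul(x, Mul(Rsqrt(Add(variance, epsilon)), gamma)), Sub(beta, Mul(moving_mean, scale))) with the Rsqrt-Mul result shared between both branches, which is why in_nodes4[1] must be the same node as in_nodes1[1]. A minimal standalone numpy sketch of the algebraic identity being exploited (variable names here are illustrative, not from the codebase):

# Standalone sketch: the Mul/Sub/Add/Rsqrt decomposition equals fused batch norm.
import numpy as np

rng = np.random.RandomState(0)
c = 8                                    # channel count (last axis in NHWC)
x = rng.randn(2, 4, 4, c).astype("float32")
gamma = rng.rand(c).astype("float32")
beta = rng.rand(c).astype("float32")
mean = rng.rand(c).astype("float32")
var = rng.rand(c).astype("float32")
eps = 1e-5

scale = gamma / np.sqrt(var + eps)       # Mul(Rsqrt(Add(var, eps)), gamma)
decomposed = x * scale + (beta - mean * scale)   # the subgraph being matched
fused = gamma * (x - mean) / np.sqrt(var + eps) + beta   # batch_norm, is_test=True

assert np.allclose(decomposed, fused, atol=1e-5)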
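
Note (not part of the patch): merge_prelu() relies on the decomposition relu(x) + alpha * 0.5 * (x - abs(x)). Since 0.5 * (x - |x|) == min(x, 0), the subgraph computes max(x, 0) + alpha * min(x, 0), which is exactly PReLU; this is what the numpy.fabs(in_nodes2[1].value - 0.5) check pins down, and in "channel" mode the slope tensor is reshaped to [1, C, 1, 1] for Paddle's NCHW layout. A minimal standalone numpy sketch of the identity:

# Standalone sketch: the Relu/Mul/Sub/Abs decomposition equals PReLU.
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(2, 4, 4, 8).astype("float32")
alpha = rng.rand(8).astype("float32")    # one slope per channel ("channel" mode)

decomposed = np.maximum(x, 0) + alpha * 0.5 * (x - np.abs(x))
prelu = np.where(x > 0, x, alpha * x)

assert np.allclose(decomposed, prelu, atol=1e-6)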