From 3ea681f3c074031538c870d727ae82f108d9ccbb Mon Sep 17 00:00:00 2001
From: 李寅
Date: Wed, 2 May 2018 17:48:54 +0800
Subject: [PATCH] Fold batchnorm (scale, offset)

---
 mace/python/tools/tf_converter_lib.py | 91 +++++++++++++--------------
 1 file changed, 44 insertions(+), 47 deletions(-)

diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py
index d562228f..dbf01b24 100644
--- a/mace/python/tools/tf_converter_lib.py
+++ b/mace/python/tools/tf_converter_lib.py
@@ -75,7 +75,8 @@ def get_input_tensor(op, index):
 
 
 class TFConverter(object):
-    def __init__(self, tf_ops, net_def, dt, device, winograd):
+    def __init__(self, graph, tf_ops, net_def, dt, device, winograd):
+        self.graph = graph
         self.net_def = net_def
         self.tf_ops = tf_ops
         self.dt = dt
@@ -494,7 +495,12 @@ class TFConverter(object):
         self.resolved_ops[op.name] = 1
 
         if len(self.tf_graph.get(op.name, [])) == 1 and \
-                self.tf_graph[op.name][0].type == 'BiasAdd':
+                self.tf_graph[op.name][0].type == 'BiasAdd' or \
+                (len(self.tf_graph.get(op.name, [])) == 1 and
+                 self.tf_graph[op.name][0].type == 'Add' and
+                 len(self.tf_graph[op.name][0].inputs) == 2 and
+                 len(self.graph.get_tensor_by_name(
+                     self.tf_graph[op.name][0].inputs[1].name).shape) == 1):
             bias_add_op = self.tf_graph[op.name][0]
             if self.device == 'gpu':
                 output_name = self.add_buffer_to_image(
@@ -650,61 +656,52 @@ class TFConverter(object):
         self.net_def.op.extend([op_def])
 
     def convert_batchnorm(self, op):
-        bn_ops = []
-        bn_ops.append(op)
-        for i in range(1, 3):
-            if len(self.tf_graph[bn_ops[i-1].name]) == 1 and \
-                    self.tf_graph[bn_ops[i-1].name][0].type == \
-                    BATCH_NORM_ORDER[i]:
-                bn_ops.append(self.tf_graph[bn_ops[i - 1].name][0])
-            else:
-                raise Exception('Invalid BatchNorm Op')
-        if len(self.tf_graph[bn_ops[2].name]) == 2 and \
-                self.tf_graph[bn_ops[2].name][0].type == \
-                BATCH_NORM_ORDER[3] and \
-                self.tf_graph[bn_ops[2].name][1].type == BATCH_NORM_ORDER[4]:
-            bn_ops.append(self.tf_graph[bn_ops[2].name][0])
-            bn_ops.append(self.tf_graph[bn_ops[2].name][1])
-        else:
-            raise Exception('Invalid BatchNorm Op')
-        bn_ops.append(self.tf_graph[bn_ops[4].name][0])
-        bn_ops.append(self.tf_graph[bn_ops[3].name][0])
-
         op_def = mace_pb2.OperatorDef()
         arg = op_def.arg.add()
         arg.name = 'T'
         arg.i = self.dt
-
-        input_name = get_input_tensor(bn_ops[3], 0).name
-        gamma = get_input_tensor(bn_ops[2], 1).name
-        beta = get_input_tensor(bn_ops[5], 0).name
-        mean = get_input_tensor(bn_ops[4], 0).name
-        variance = get_input_tensor(bn_ops[0], 0).name
-
-        op_def.name = op.name[:-4]  # remove /add
-        op_def.type = 'BatchNorm'
-        if self.device == 'gpu':
-            op_def.input.extend([input_name])
-            for tensor_name in [gamma, beta, mean, variance]:
-                output_name = self.add_buffer_to_image(tensor_name, "ARGUMENT")
-                op_def.input.extend([output_name])
-        else:
-            op_def.input.extend([input_name, gamma, beta, mean, variance])
-        op_def.output.extend([output.name for output in bn_ops[6].outputs])
-        self.add_output_shape(bn_ops[6].outputs, op_def)
-        epsilon_arg = op_def.arg.add()
-        epsilon_arg.name = 'epsilon'
-        epsilon_arg.f = get_input_tensor(op, 1).eval().astype(np.float)
         data_format_arg = op_def.arg.add()
         data_format_arg.name = 'data_format'
         if self.device == 'cpu':
             data_format_arg.s = 'NCHW'
         else:
             data_format_arg.s = 'NHWC'
-        self.unused_tensor.add(get_input_tensor(op, 1).name)
+        op_def.name = op.name
+        op_def.type = 'FoldedBatchNorm'
+
+        add_op = self.tf_graph[op.name][0]
+        scale_tensor = get_input_tensor(op, 1)
+        offset_tensor = get_input_tensor(add_op, 1)
+
+        input_names = [scale_tensor.name, offset_tensor.name]
+        op_def.input.extend([op.inputs[0].name])
+        if self.device == 'gpu':
+            for name in input_names:
+                output_name = self.add_buffer_to_image(name, "ARGUMENT")
+                op_def.input.extend([output_name])
+        else:
+            op_def.input.extend([name for name in input_names])
+
+        self.resolved_ops[op.name] = 1
+        self.resolved_ops[add_op.name] = 1
+        final_op = add_op
+
+        if len(self.tf_graph.get(add_op.name, [])) == 1 \
+                and self.tf_graph[add_op.name][0].type in activation_name_map:
+            activation_op = self.tf_graph[add_op.name][0]
+            fused_act_arg = op_def.arg.add()
+            fused_act_arg.name = 'activation'
+            fused_act_arg.s = activation_name_map[activation_op.type]
+            if activation_op.type == 'Relu6':
+                max_limit_arg = op_def.arg.add()
+                max_limit_arg.name = 'max_limit'
+                max_limit_arg.f = 6
+            final_op = activation_op
+            self.resolved_ops[activation_op.name] = 1
+
+        op_def.output.extend([final_op.outputs[0].name])
+        self.add_output_shape([final_op.outputs[0]], op_def)
         self.net_def.op.extend([op_def])
-        for i in range(0, 7):
-            self.resolved_ops[bn_ops[i].name] = 1
 
     def convert_pooling(self, op):
         op_def = self.net_def.op.add()
@@ -1155,7 +1152,7 @@ class TFConverter(object):
             self.convert_conv2d(op)
         elif op.type == 'FusedBatchNorm':
             self.convert_fused_batchnorm(op)
-        elif op.type == 'Add' and op.name.endswith('batchnorm/add'):
+        elif op.type == 'Mul' and op.name.find('batchnorm/mul') != -1:
             self.convert_batchnorm(op)
         elif op.type == 'AvgPool' or op.type == 'MaxPool':
             self.convert_pooling(op)
@@ -1382,7 +1379,7 @@ def convert_to_mace_pb(model_file, input_node, input_shape, output_node,
     with session.graph.as_default() as graph:
         tf.import_graph_def(input_graph_def, name="")
         ops = graph.get_operations()
-        converter = TFConverter(ops, net_def, dt, device, winograd)
+        converter = TFConverter(graph, ops, net_def, dt, device, winograd)
         converter.convert(input_nodes, output_nodes)
         optimizer = Optimizer(net_def, device)
         net_def = optimizer.optimize()
--
GitLab
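
Note: the rewritten convert_batchnorm() targets the Mul/Add pair that
TensorFlow's fold_batch_norms graph transform leaves behind in a frozen
graph in place of an unfused batch norm. Below is a minimal NumPy sketch
of the identity the pass relies on; it is illustrative only, not part of
the patch, and the gamma/beta/mean/var names are hypothetical stand-ins
for the frozen graph's constants.

    import numpy as np

    # Inference-time batch norm:
    #   y = gamma * (x - mean) / sqrt(var + eps) + beta
    # Folding precomputes two per-channel 1-D constants:
    #   scale  = gamma / sqrt(var + eps)
    #   offset = beta - mean * scale
    # which reduces the subgraph to y = x * scale + offset, i.e. the
    # 'batchnorm/mul' Mul followed by an Add whose second input is a
    # 1-D tensor, exactly the pattern matched above.

    np.random.seed(0)
    x = np.random.randn(1, 8, 8, 4).astype(np.float32)    # NHWC input
    gamma = np.random.randn(4).astype(np.float32)
    beta = np.random.randn(4).astype(np.float32)
    mean = np.random.randn(4).astype(np.float32)
    var = (np.random.rand(4) + 0.5).astype(np.float32)
    eps = 1e-3

    y_bn = gamma * (x - mean) / np.sqrt(var + eps) + beta  # unfused form

    scale = gamma / np.sqrt(var + eps)
    offset = beta - mean * scale
    y_folded = x * scale + offset                          # Mul + Add

    assert np.allclose(y_bn, y_folded, atol=1e-5)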