Commit 3ea681f3 authored by 李寅

Fold batchnorm (scale, offset)

Parent 5270b335
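
The change below teaches the converter to match batch norm that has already been folded into a per-channel scale and offset (a Mul whose single consumer is an Add with a rank-1 second input), replacing the old seven-op unfolded pattern matcher. As a minimal sketch of the algebra folding relies on (illustrative names, not part of this commit): at inference time y = gamma * (x - mean) / sqrt(variance + epsilon) + beta collapses into y = scale * x + offset, so only two 1-D tensors survive.

import numpy as np

def fold_batchnorm(gamma, beta, mean, variance, epsilon=1e-3):
    # Per-channel scale/offset equivalent to inference-time batch norm.
    scale = gamma / np.sqrt(variance + epsilon)
    offset = beta - mean * scale
    return scale, offset

# Sanity check of the identity on random per-channel statistics.
c = 16
gamma, beta = np.random.rand(c), np.random.rand(c)
mean, variance = np.random.rand(c), np.random.rand(c)
x = np.random.rand(4, 8, 8, c)  # NHWC activations
scale, offset = fold_batchnorm(gamma, beta, mean, variance)
assert np.allclose(scale * x + offset,
                   gamma * (x - mean) / np.sqrt(variance + 1e-3) + beta)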
@@ -75,7 +75,8 @@ def get_input_tensor(op, index):
 class TFConverter(object):
-    def __init__(self, tf_ops, net_def, dt, device, winograd):
+    def __init__(self, graph, tf_ops, net_def, dt, device, winograd):
+        self.graph = graph
         self.net_def = net_def
         self.tf_ops = tf_ops
         self.dt = dt
@@ -494,7 +495,12 @@ class TFConverter(object):
         self.resolved_ops[op.name] = 1
         if len(self.tf_graph.get(op.name, [])) == 1 and \
-                self.tf_graph[op.name][0].type == 'BiasAdd':
+                self.tf_graph[op.name][0].type == 'BiasAdd' or \
+                (len(self.tf_graph[op.name]) == 1 and
+                 self.tf_graph[op.name][0].type == 'Add' and
+                 len(self.tf_graph[op.name][0].inputs) == 2 and
+                 len(self.graph.get_tensor_by_name(
+                     self.tf_graph[op.name][0].inputs[1].name).shape) == 1):
             bias_add_op = self.tf_graph[op.name][0]
             if self.device == 'gpu':
                 output_name = self.add_buffer_to_image(
@@ -650,61 +656,52 @@ class TFConverter(object):
         self.net_def.op.extend([op_def])
 
     def convert_batchnorm(self, op):
-        bn_ops = []
-        bn_ops.append(op)
-        for i in range(1, 3):
-            if len(self.tf_graph[bn_ops[i-1].name]) == 1 and \
-                    self.tf_graph[bn_ops[i-1].name][0].type == BATCH_NORM_ORDER[i]:
-                bn_ops.append(self.tf_graph[bn_ops[i - 1].name][0])
-            else:
-                raise Exception('Invalid BatchNorm Op')
-        if len(self.tf_graph[bn_ops[2].name]) == 2 and \
-                self.tf_graph[bn_ops[2].name][0].type == \
-                BATCH_NORM_ORDER[3] and \
-                self.tf_graph[bn_ops[2].name][1].type == BATCH_NORM_ORDER[4]:
-            bn_ops.append(self.tf_graph[bn_ops[2].name][0])
-            bn_ops.append(self.tf_graph[bn_ops[2].name][1])
-        else:
-            raise Exception('Invalid BatchNorm Op')
-        bn_ops.append(self.tf_graph[bn_ops[4].name][0])
-        bn_ops.append(self.tf_graph[bn_ops[3].name][0])
         op_def = mace_pb2.OperatorDef()
         arg = op_def.arg.add()
         arg.name = 'T'
         arg.i = self.dt
-        input_name = get_input_tensor(bn_ops[3], 0).name
-        gamma = get_input_tensor(bn_ops[2], 1).name
-        beta = get_input_tensor(bn_ops[5], 0).name
-        mean = get_input_tensor(bn_ops[4], 0).name
-        variance = get_input_tensor(bn_ops[0], 0).name
-        op_def.name = op.name[:-4]  # remove /add
-        op_def.type = 'BatchNorm'
-        if self.device == 'gpu':
-            op_def.input.extend([input_name])
-            for tensor_name in [gamma, beta, mean, variance]:
-                output_name = self.add_buffer_to_image(tensor_name, "ARGUMENT")
-                op_def.input.extend([output_name])
-        else:
-            op_def.input.extend([input_name, gamma, beta, mean, variance])
-        op_def.output.extend([output.name for output in bn_ops[6].outputs])
-        self.add_output_shape(bn_ops[6].outputs, op_def)
-        epsilon_arg = op_def.arg.add()
-        epsilon_arg.name = 'epsilon'
-        epsilon_arg.f = get_input_tensor(op, 1).eval().astype(np.float)
-        data_format_arg = op_def.arg.add()
-        data_format_arg.name = 'data_format'
-        if self.device == 'cpu':
-            data_format_arg.s = 'NCHW'
-        else:
-            data_format_arg.s = 'NHWC'
-        self.unused_tensor.add(get_input_tensor(op, 1).name)
+        op_def.name = op.name
+        op_def.type = 'FoldedBatchNorm'
+        add_op = self.tf_graph[op.name][0]
+        scale_tensor = get_input_tensor(op, 1)
+        offset_tensor = get_input_tensor(add_op, 1)
+        input_names = [scale_tensor.name, offset_tensor.name]
+        op_def.input.extend([op.inputs[0].name])
+        if self.device == 'gpu':
+            for name in input_names:
+                output_name = self.add_buffer_to_image(name, "ARGUMENT")
+                op_def.input.extend([output_name])
+        else:
+            op_def.input.extend([name for name in input_names])
+        self.resolved_ops[op.name] = 1
+        self.resolved_ops[add_op.name] = 1
+        final_op = add_op
+        if len(self.tf_graph[op.name]) == 1 \
+                and self.tf_graph[op.name][0].type in activation_name_map:
+            activation_op = self.tf_graph[op.name][0]
+            fused_act_arg = op_def.arg.add()
+            fused_act_arg.name = 'activation'
+            fused_act_arg.s = activation_name_map[activation_op.type]
+            if activation_op.type == 'Relu6':
+                max_limit_arg = op_def.arg.add()
+                max_limit_arg.name = 'max_limit'
+                max_limit_arg.f = 6
+            final_op = activation_op
+            self.resolved_ops[activation_op.name] = 1
+        op_def.output.extend([final_op.outputs[0].name])
+        self.add_output_shape([final_op.outputs[0]], op_def)
         self.net_def.op.extend([op_def])
-        for i in range(0, 7):
-            self.resolved_ops[bn_ops[i].name] = 1
 
     def convert_pooling(self, op):
         op_def = self.net_def.op.add()
@@ -1155,7 +1152,7 @@ class TFConverter(object):
             self.convert_conv2d(op)
         elif op.type == 'FusedBatchNorm':
            self.convert_fused_batchnorm(op)
-        elif op.type == 'Add' and op.name.endswith('batchnorm/add'):
+        elif op.type == 'Mul' and op.name.find('batchnorm/mul') != -1:
             self.convert_batchnorm(op)
         elif op.type == 'AvgPool' or op.type == 'MaxPool':
             self.convert_pooling(op)
@@ -1382,7 +1379,7 @@ def convert_to_mace_pb(model_file, input_node, input_shape, output_node,
         with session.graph.as_default() as graph:
             tf.import_graph_def(input_graph_def, name="")
             ops = graph.get_operations()
-            converter = TFConverter(ops, net_def, dt, device, winograd)
+            converter = TFConverter(graph, ops, net_def, dt, device, winograd)
             converter.convert(input_nodes, output_nodes)
 
             optimizer = Optimizer(net_def, device)
             net_def = optimizer.optimize()
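For reference, a hypothetical TF 1.x miniature of the subgraph the new dispatch matches: a Mul named '.../batchnorm/mul' whose consumer is an Add taking a rank-1 second input. All names and shapes here are illustrative, not taken from the commit.

import tensorflow as tf

x = tf.placeholder(tf.float32, [1, 8, 8, 16], name='input')
with tf.name_scope('conv1'):
    with tf.name_scope('batchnorm'):
        scale = tf.constant(1.0, shape=[16], name='scale')    # folded gamma / sqrt(var + eps)
        offset = tf.constant(0.5, shape=[16], name='offset')  # folded beta - mean * scale
        mul = tf.multiply(x, scale, name='mul')  # type 'Mul', name 'conv1/batchnorm/mul'
        y = tf.add(mul, offset, name='add')      # type 'Add', second input is rank-1

convert_batchnorm turns this pair into a single FoldedBatchNorm op whose inputs are x, scale and offset, and, when the pair's sole consumer is a Relu or Relu6, folds the activation in as well via the 'activation' argument (plus 'max_limit' = 6 for Relu6).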