diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 27df84accf8859a20454f4c512ce688ccea8081a..823090d3c8720256db09bef36cbb8ca3759c0c4d 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -30,14 +30,13 @@ def convert_tensor(op, tensor): else: raise Exception("Not supported tensor type: " + tf_dt.name) - def get_input_tensor(op, index): input_tensor = op.inputs[index] if input_tensor.op.type == 'Reshape': input_tensor = get_input_tensor(input_tensor.op, 0) return input_tensor -def convert_ops(unresolved_ops, net_def): +def convert_ops(ops_map, unresolved_ops, net_def): ops_count = len(unresolved_ops) resolved_count = 1 @@ -77,6 +76,28 @@ def convert_ops(unresolved_ops, net_def): bias_add_op = unresolved_ops[1] op_def.input.extend([bias_add_op.inputs[1].name]) resolved_count = 2 + + if ops_count >= 3 and unresolved_ops[1].type == 'Relu': + op_def.name = "FusedConv2D" + resolved_count = 3 + elif first_op.type == 'FusedBatchNorm': + op_def = net_def.op.add() + op_def.name = first_op.name + first_op.type = 'BatchNorm' + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + epsilon_arg = op_def.arg.add() + epsilon_arg.name = 'epsilon' + epsilon_arg.f = first_op.get_attr('epsilon') + data_format_arg = op_def.arg.add() + data_format_arg.name = 'data_format' + data_format_arg.s = 'NHWC' elif first_op.type == 'Add' and first_op.name.endswith( 'batchnorm/add') and ops_count > 7: add_op = first_op @@ -90,6 +111,7 @@ def convert_ops(unresolved_ops, net_def): mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add': raise Exception('Invalid BatchNorm Op') + get_input_tensor(mul_1_op, 0) input_name = get_input_tensor(mul_1_op, 0).name gamma = get_input_tensor(mul_op, 1).name beta = get_input_tensor(sub_op, 0).name @@ -168,15 +190,33 @@ def convert_ops(unresolved_ops, net_def): op_def = net_def.op.add() op_def.name = first_op.name op_def.type = "Concat" - op_def.input.extend([input.name for input in first_op.inputs]) + op_def.input.extend([first_op.inputs[i] for i in xrange(2)]) op_def.output.extend([output.name for output in first_op.outputs]) + axis_arg = op_def.arg.add() + axis_arg.name = 'axis' + axis_arg.i = get_input_tensor(first_op, 2).eval().astype(np.int32) output_shapes = [] for output in first_op.outputs: output_shape = mace_pb2.OutputShape() output_shape.dims.extend(output.shape.as_list()) output_shapes.append(output_shape) op_def.output_shape.extend(output_shapes) - elif first_op.type in ['Relu', 'ResizeBilinear', 'SpaceToBatchND', + elif first_op.type == 'ResizeBilinear': + op_def = net_def.op.add() + op_def.name = first_op.name + op_def.type = "ResizeBilinear" + op_def.input.extend(first_op.inputs[0]) + op_def.output.extend([output.name for output in first_op.outputs]) + size_arg = op_def.arg.add() + size_arg.name = 'size' + size_arg.ints.extend(get_input_tensor(first_op, 1).eval().astype(np.int32).flat) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + elif first_op.type in ['Relu', 'SpaceToBatchND', 'BatchToSpaceND', 'BiasAdd', 'FusedBatchNorm']: op_def = net_def.op.add() op_def.name = first_op.name @@ -204,9 +244,16 @@ def convert_to_mace_pb(input_graph_def): with session.graph.as_default() as graph: tf.import_graph_def(input_graph_def, name="") ops = graph.get_operations() + ops_map = {} + for op in ops: + if op.name not in ops_map.keys(): + ops_map[op.name] = op + else: + raise ValueError("Duplicate op names detected for ", op.name) + unresolved_ops = ops while len(unresolved_ops) > 0: - convert_ops(unresolved_ops, net_def) + convert_ops(ops_map, unresolved_ops, net_def) print "Done."