diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto index 7a0ad7211737a8efd41eafa71b9a2c07ed2815da..119e1fed79a7cad1374cdb3891745ec2c83716bb 100644 --- a/mace/proto/mace.proto +++ b/mace/proto/mace.proto @@ -77,7 +77,7 @@ message OperatorDef { optional string name = 3; optional string type = 4; repeated Argument arg = 5; - optional OutputShape output_shape = 6; + repeated OutputShape output_shape = 6; // Memory optimization: only support one single output op optional int32 mem_id = 10 [default = -1]; diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 97575bf23ce9583f1db75ce37d5bc699d0f0189e..27df84accf8859a20454f4c512ce688ccea8081a 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -18,15 +18,6 @@ def convert_tensor(op, tensor): tensor.name = op.outputs[0].name shape = list(tf_tensor.shape) - if (op.name.find('pointwise_kernel') != -1 or - op.name.find('depthwise_kernel') != -1 or - op.name.endswith('weights') or - op.name.endswith('kernel')) \ - and op.outputs[0].consumers()[0].type.find('Conv') != -1: - if op.outputs[0].consumers()[0].get_attr('data_format') == 'NHWC': - tf_tensor = np.transpose(tf_tensor, axes=(3, 2, 0, 1)) - shape = [shape[3], shape[2], shape[0], shape[1]] - # print (tensor.name, shape) tensor.dims.extend(shape) tf_dt = op.get_attr('dtype') @@ -66,6 +57,12 @@ def convert_ops(unresolved_ops, net_def): op_def.type = first_op.type op_def.input.extend([input.name for input in first_op.inputs]) op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) padding_arg = op_def.arg.add() padding_arg.name = 'padding' padding_arg.i = padding_mode[first_op.get_attr('padding')] @@ -74,7 +71,7 @@ def convert_ops(unresolved_ops, net_def): strides_arg.ints.extend(first_op.get_attr('strides')[1:3]) data_format_arg = op_def.arg.add() data_format_arg.name = 'data_format' - data_format_arg.s = 'NCHW' + data_format_arg.s = 'NHWC' if ops_count >= 2 and unresolved_ops[1].type == 'BiasAdd': bias_add_op = unresolved_ops[1] @@ -105,6 +102,12 @@ def convert_ops(unresolved_ops, net_def): op_def.type = 'BatchNorm' op_def.input.extend([input_name, gamma, beta, mean, variance, epsilon]) op_def.output.extend([output.name for output in add_1_op.outputs]) + output_shapes = [] + for output in add_1_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) resolved_count = 7 elif first_op.type == 'Relu6': @@ -113,6 +116,12 @@ def convert_ops(unresolved_ops, net_def): op_def.type = 'Relu' op_def.input.extend([input.name for input in first_op.inputs]) op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) max_limit_arg = op_def.arg.add() max_limit_arg.name = 'max_limit' max_limit_arg.f = 6 @@ -122,6 +131,12 @@ def convert_ops(unresolved_ops, net_def): op_def.type = 'Pooling' op_def.input.extend([input.name for input in first_op.inputs]) op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) pooling_type_arg = op_def.arg.add() pooling_type_arg.name = 'pooling_type' pooling_type_arg.i = pooling_type_mode[first_op.type] @@ -136,21 +151,46 @@ def convert_ops(unresolved_ops, net_def): kernels_arg.ints.extend(first_op.get_attr('ksize')[1:3]) data_format_arg = op_def.arg.add() data_format_arg.name = 'data_format' - data_format_arg.s = 'NCHW' + data_format_arg.s = 'NHWC' elif first_op.type == 'Add': op_def = net_def.op.add() op_def.name = first_op.name op_def.type = "AddN" op_def.input.extend([input.name for input in first_op.inputs]) op_def.output.extend([output.name for output in first_op.outputs]) - elif first_op.type in ['Relu', 'ResizeBilinear', 'SpaceToBatchND', 'BatchToSpaceND']: + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + elif first_op.type == 'ConcatV2': + op_def = net_def.op.add() + op_def.name = first_op.name + op_def.type = "Concat" + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + elif first_op.type in ['Relu', 'ResizeBilinear', 'SpaceToBatchND', + 'BatchToSpaceND', 'BiasAdd', 'FusedBatchNorm']: op_def = net_def.op.add() op_def.name = first_op.name op_def.type = first_op.type op_def.input.extend([input.name for input in first_op.inputs]) op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) else: - raise Exception('Unknown Op: ' + first_op.name) + raise Exception('Unknown Op: %s, type: %s' % (first_op.name, first_op.type)) pass for i in range(resolved_count):