diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 1818070aa08433db815262240f641a4853f99451..8c272ac2442db86b33e9617d7cf9e05f1ebb477d 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -578,7 +578,7 @@ class TensorflowConverter(base_converter.ConverterInterface): axis_arg = op.arg.add() axis_arg.name = MaceKeyword.mace_axis_str axis = tf_op.inputs[-1].eval().astype(np.int32) - axis = 4 + axis if axis < 0 else axis + axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis axis_arg.i = axis self._skip_tensor.add(tf_op.inputs[-1].name) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index a6295b5280e12f5bc84654fb7c91a37954a3ccec..850519f5bcf37d57ff8134f1d8f4be4386eb902b 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -751,6 +751,15 @@ class Transformer(base_converter.ConverterInterface): "only support concat at " "channel dimension") arg.i = 3 + producer = self._producer[op.input[0]] + input_shape = producer.output_shape[0].dims + if producer.type == MaceOp.FullyConnected.name and \ + len(input_shape) == 2: + axis_arg = ConverterUtil.get_arg( + op, MaceKeyword.mace_axis_str) + if axis_arg.i == 1 \ + and self._target_data_format == DataFormat.NHWC: # noqa + axis_arg.i = 3 elif op.type == MaceOp.Squeeze.name: for arg in op.arg: @@ -938,7 +947,10 @@ class Transformer(base_converter.ConverterInterface): input_shape = list(input_op.output_shape[0].dims) input_data_format = ConverterUtil.data_format(input_op) weight.dims[:] = [weight.dims[0]] + input_shape[1:] - if input_data_format == DataFormat.NHWC: + if len(input_shape) == 2: + weight.dims[:] = weight.dims[:] + [1, 1] + if input_data_format == DataFormat.NHWC and \ + len(input_shape) == 4: self.transpose_shape(weight.dims, [0, 3, 1, 2]) return False @@ -1113,31 +1125,48 @@ class Transformer(base_converter.ConverterInterface): net = self._model filter_format = self.filter_format() for op in net.op: - # transform reshape + matmul -> fc + # transform input(4D) -> reshape(2D) -> matmul to fc # work for TensorFlow + if op.type == MaceOp.Reshape.name and \ + op.input[1] in self._consts and \ + len(op.output_shape[0].dims) == 2 and \ + filter_format == FilterFormat.HWIO: + input_op = self._producer[op.input[0]] + input_shape = input_op.output_shape[0].dims + # check input op + if len(input_shape) == 4 and \ + np.prod(input_shape[1:]) == op.output_shape[0].dims[1]: + is_fc = True + consumers = self._consumers[op.output[0]] + # check matmul op + for matmul_op in consumers: + if matmul_op.type != MaceOp.MatMul.name: + is_fc = False + else: + weight = self._consts[matmul_op.input[1]] + if len(weight.dims) != 2 or \ + weight.dims[0] != op.output_shape[0].dims[1]: + is_fc = False + if is_fc: + print 'convert reshape and matmul to fc' + self.safe_remove_node(op, input_op, + remove_input_tensor=True) + for matmul_op in consumers: + weight = self._consts[matmul_op.input[1]] + matmul_op.type = MaceOp.FullyConnected.name + weight_data = np.array(weight.float_data).reshape( + weight.dims) + weight.dims[:] = input_shape[1:] + \ + [weight_data.shape[1]] + return True + + # transform input(2D) -> matmul to fc if op.type == MaceOp.MatMul.name and \ filter_format == FilterFormat.HWIO: producer = self._producer[op.input[0]] weight = self._consts[op.input[1]] - if len(weight.dims) == 2 \ - and producer.type == MaceOp.Reshape.name \ - and len(producer.output) == 1 \ - and producer.input[1] in self._consts \ - and len(producer.output_shape[0].dims) == 2: - input_op = self._producer[producer.input[0]] - input_shape = input_op.output_shape[0].dims - feature_size = np.prod(input_shape[1:]) - self.safe_remove_node(producer, input_op, - remove_input_tensor=True) - if feature_size == producer.output_shape[0].dims[1]: - print 'convert reshape and matmul to fc' - op.type = MaceOp.FullyConnected.name - weight_data = np.array(weight.float_data).reshape( - weight.dims) - weight.dims[:] = input_shape[1:] + \ - [weight_data.shape[1]] - return True - elif len(weight.dims) == 2 and \ + if len(weight.dims) == 2 and \ + producer.type != MaceOp.Reshape.name and \ len(producer.output_shape[0].dims) == 2 and \ weight.dims[0] == producer.output_shape[0].dims[1]: print 'convert matmul to fc'