diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index aeb626a681a579a120008c81082bed0e505f582e..a2158081c1575465712f928f6281525521c0ff62 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -256,6 +256,7 @@ class TransformerRule(Enum): TRANSPOSE_MATMUL_WEIGHT = 34 FOLD_EMBEDDING_LOOKUP = 35 TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36 + FOLD_FC_RESHAPE = 37 class ConverterInterface(object): @@ -461,6 +462,7 @@ class ConverterOption(object): TransformerRule.FOLD_SQRDIFF_MEAN, TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC, TransformerRule.RESHAPE_FC_WEIGHT, + TransformerRule.FOLD_FC_RESHAPE, # Model data format related transformation TransformerRule.TRANSPOSE_FILTERS, TransformerRule.TRANSPOSE_DATA_FORMAT, diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 5e564fa41a36529294c39837d41acf0a0ef3e653..1fb61f715f05ea406294afc6fa53cde831869a43 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -75,6 +75,8 @@ class Transformer(base_converter.ConverterInterface): TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters, TransformerRule.TRANSPOSE_MATMUL_WEIGHT: self.transpose_matmul_weight, + TransformerRule.FOLD_FC_RESHAPE: + self.fold_fc_reshape, TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format, TransformerRule.ADD_WINOGRAD_ARG: self.add_winograd_arg, TransformerRule.ADD_IN_OUT_TENSOR_INFO: @@ -1227,11 +1229,24 @@ class Transformer(base_converter.ConverterInterface): return True return False + def is_after_fc(self, op): + while op.input[0] in self._producer: + producer = self._producer[op.input[0]] + if producer.type in [MaceOp.Activation.name, MaceOp.BiasAdd.name]: + op = producer + continue + elif producer.type == MaceOp.FullyConnected.name: + return True + else: + return False + return False + def transform_matmul_to_fc(self): net = self._model filter_format = self.filter_format() for op in net.op: - # transform input(4D) -> reshape(2D) -> matmul to fc + # transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)` + # fc output is 2D in transformer, using as 4D in op kernel # work for TensorFlow if op.type == MaceOp.Reshape.name and \ len(op.input) == 2 and \ @@ -1268,6 +1283,21 @@ class Transformer(base_converter.ConverterInterface): [weight_data.shape[1]] return True + # transform `fc1(2D) -> matmul` to `fc1(2D) -> fc1(2D)` + if op.type == MaceOp.MatMul.name and \ + filter_format == FilterFormat.HWIO: + producer = self._producer[op.input[0]] + weight = self._consts[op.input[1]] + if len(weight.dims) == 2 and self.is_after_fc(op) and \ + len(producer.output_shape[0].dims) == 2 and \ + weight.dims[0] == producer.output_shape[0].dims[1]: + six.print_('convert matmul to fc') + op.type = MaceOp.FullyConnected.name + weight_data = np.array(weight.float_data).reshape( + weight.dims) + weight.dims[:] = [1, 1] + list(weight_data.shape) + return True + return False def update_float_op_data_type(self): @@ -1750,3 +1780,22 @@ class Transformer(base_converter.ConverterInterface): shape_tensor.data_type = mace_pb2.DT_INT32 shape_tensor.int32_data.extend(dims) op.input.append(shape_tensor.name) + + def fold_fc_reshape(self): + net = self._model + for op in net.op: + # whether to reshape fc output(default 4D) + if op.type == MaceOp.FullyConnected.name: + consumers = self._consumers[op.output[0]] + op_output_shape = op.output_shape[0].dims[:] + for consumer in consumers: + if consumer.type == MaceOp.Reshape.name and \ + consumer.input[1] in self._consts and \ + self._consts[consumer.input[1]].int32_data[:] == \ + [op_output_shape[0], 1, 1, op_output_shape[1]]: + # work for tensorflow + net.tensors.remove(self._consts[consumer.input[1]]) + del consumer.input[1] + self.safe_remove_node(consumer, None) + return True + return False diff --git a/tools/converter.py b/tools/converter.py index 4af30403ad82550d86ec7b2d81c2129b732cdc2c..486014ab315a9854739701ee3ed0d939133cfe03 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -411,7 +411,7 @@ def format_model_config(flags): ModuleName.YAML_CONFIG, "'input_data_formats' must be in " + str(DataFormatStrs) + ", but got " - + input_data_formats) + + input_data_format) else: subgraph[YAMLKeyword.input_data_formats] = [DataFormat.NHWC] @@ -431,7 +431,7 @@ def format_model_config(flags): subgraph[YAMLKeyword.output_data_formats]: mace_check(output_data_format in DataFormatStrs, ModuleName.YAML_CONFIG, - "'input_data_formats' must be in " + "'output_data_formats' must be in " + str(DataFormatStrs)) else: subgraph[YAMLKeyword.output_data_formats] = [DataFormat.NHWC]