diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 9a5440f440b9307f268956361d5f31e2eb3505c1..cfd8409abfaad1de1b066583665afd9f47d4af41 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -195,6 +195,7 @@ class TransformerRule(Enum):
     QUANTIZE_NODES = 23
     ADD_QUANTIZE_TENSOR_RANGE = 24
     QUANTIZE_WEIGHTS = 25
+    TRANSPOSE_MATMUL_WEIGHT = 26
 
 
 class ConverterInterface(object):
@@ -345,6 +346,8 @@ class ConverterOption(object):
                 TransformerRule.FOLD_ACTIVATION,
                 TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
                 TransformerRule.RESHAPE_FC_WEIGHT,
+                # Transpose the weight of matmul if necessary
+                TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
                 # Model data format related transformation
                 TransformerRule.TRANSPOSE_FILTERS,
                 TransformerRule.TRANSPOSE_DATA_FORMAT,
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 16d9eae007d305e12276d9c0bcd66e0ba1e9f1d3..5f179d653f25e5773587d178b6668d2fd0641356 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -64,6 +64,9 @@ class Transformer(base_converter.ConverterInterface):
             TransformerRule.FOLD_BIASADD: self.fold_biasadd,
             TransformerRule.FLATTEN_ATROUS_CONV: self.flatten_atrous_conv,
             TransformerRule.FOLD_ACTIVATION: self.fold_activation,
+            # TODO(liuqi): should move to transpose_filter
+            TransformerRule.TRANSPOSE_MATMUL_WEIGHT:
+                self.transpose_matmul_weight,
             TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
             TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
             TransformerRule.ADD_IN_OUT_TENSOR_INFO:
@@ -954,6 +957,52 @@ class Transformer(base_converter.ConverterInterface):
 
         return False
 
+    def transpose_matmul_weight(self):
+        """Fold MatMul transpose flags into constant 2-D weights.
+
+        Runs only when the target device is CPU. For each MatMul input
+        that is a 2-D constant whose transpose_a/transpose_b argument
+        equals 1, the tensor is transposed in place and the argument is
+        reset to 0.
+
+        Returns:
+            False on every path, matching the early exit below.
+        """
+        if self._option.device != DeviceType.CPU.value:
+            return False
+        net = self._model
+        transpose_arg_names = [MaceKeyword.mace_transpose_a_str,
+                               MaceKeyword.mace_transpose_b_str]
+        # A weight shared by several MatMul ops must be transposed only
+        # once; a second in-place transpose would undo the first.
+        transposed = set()
+        for op in net.op:
+            if op.type != MaceOp.MatMul.name:
+                continue
+            # zip() stops after two inputs, so an op with extra inputs
+            # cannot index past transpose_arg_names.
+            for name, arg_name in zip(op.input, transpose_arg_names):
+                if name not in self._consts:
+                    continue
+                weight = self._consts[name]
+                if len(weight.dims) != 2:
+                    continue
+                arg = ConverterUtil.get_arg(op, arg_name)
+                if arg is None or arg.i != 1:
+                    continue
+                print('convert matmul')
+                if name not in transposed:
+                    weight_data = np.array(weight.float_data).reshape(
+                        weight.dims)
+                    weight_data = weight_data.transpose(1, 0)
+                    weight.float_data[:] = weight_data.flat
+                    weight.dims[:] = weight_data.shape
+                    transposed.add(name)
+                arg.i = 0
+        # NOTE(review): a shared weight that another op reads with its
+        # transpose flag unset is still changed underneath that op.
+        return False
+
     def transpose_filters(self):
         net = self._model
         filter_format = self.filter_format()
diff --git a/tools/converter.py b/tools/converter.py
index 509d1eceed52ea185c996c06f6045b906c8bb45b..4eb6405c9fd682eb7b4a15485816868d3c6795e7 100644
--- a/tools/converter.py
+++ b/tools/converter.py
@@ -457,6 +457,7 @@ def format_model_config(flags):
                            "'%s' is necessary in subgraph" % key)
                 if not isinstance(value, list):
                     subgraph[key] = [value]
+                subgraph[key] = [str(v) for v in subgraph[key]]
 
             input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
             if input_data_types:
@@ -507,6 +508,8 @@ def format_model_config(flags):
                     subgraph[YAMLKeyword.input_ranges] = [input_ranges]
                 else:
                     subgraph[YAMLKeyword.input_ranges] = input_ranges
+                subgraph[YAMLKeyword.input_ranges] =\
+                    [str(v) for v in subgraph[YAMLKeyword.input_ranges]]
 
         for key in [YAMLKeyword.limit_opencl_kernel_time,
                     YAMLKeyword.nnlib_graph_mode,
diff --git a/tools/generate_data.py b/tools/generate_data.py
index 1e485f2034aeaad6e3d25ccfda24936cd827e880..d8d10ea2e68c16684bfa1dd0f5e7cc08a63a7ada 100644
--- a/tools/generate_data.py
+++ b/tools/generate_data.py
@@ -44,19 +44,22 @@ def generate_input_data(input_file, input_node, input_shape, input_ranges,
                         input_data_type):
     input_names = [name for name in input_node.split(',')]
     input_shapes = [shape for shape in input_shape.split(':')]
+
     if input_ranges:
         input_ranges = [r for r in input_ranges.split(':')]
     else:
-        input_ranges = [[-1, 1]] * len(input_names)
+        input_ranges = ["-1,1"] * len(input_names)
     if input_data_type:
         input_data_types = [data_type
                             for data_type in input_data_type.split(',')]
     else:
         input_data_types = ['float32'] * len(input_names)
+
     assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types)  # noqa
     for i in range(len(input_names)):
         shape = [int(x) for x in input_shapes[i].split(',')]
-        generate_data(input_names[i], shape, input_file, input_ranges[i],
+        input_range = [float(x) for x in input_ranges[i].split(',')]
+        generate_data(input_names[i], shape, input_file, input_range,
                       input_data_types[i])
     print "Generate input file done."
 