From 95536bcbf7851e393a0d6ee85a3a63ca49add64b Mon Sep 17 00:00:00 2001
From: liutuo
Date: Wed, 5 Jun 2019 11:42:32 +0800
Subject: [PATCH] fix transpose matmul weight when multiple matmul ops share
 one kernel

---
 .../tools/converter_tool/transformer.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index c1fd9ee0..6a07db05 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -1047,6 +1047,7 @@ class Transformer(base_converter.ConverterInterface):
         if self._option.device != DeviceType.CPU.value:
             return False
         net = self._model
+        transposed_weights = []
         for op in net.op:
             if op.type == MaceOp.MatMul.name:  # noqa
                 rhs = op.input[1]
@@ -1058,15 +1059,17 @@
                         arg.name = MaceKeyword.mace_transpose_b_str
                         arg.i = 0
                     if arg.i == 0:
-                        filter = self._consts[rhs]
-                        filter_data = np.array(filter.float_data).reshape(
-                            filter.dims)
-                        filter_data = filter_data.transpose(1, 0)
-                        filter.float_data[:] = filter_data.flat
-                        filter.dims[:] = filter_data.shape
                         arg.i = 1
-                        six.print_('Transpose matmul weight to shape:',
-                                   filter.dims)
+                        if rhs not in transposed_weights:
+                            filter = self._consts[rhs]
+                            filter_data = np.array(filter.float_data).reshape(
+                                filter.dims)
+                            filter_data = filter_data.transpose(1, 0)
+                            filter.float_data[:] = filter_data.flat
+                            filter.dims[:] = filter_data.shape
+                            transposed_weights.append(rhs)
+                            six.print_('Transpose matmul weight to shape:',
+                                       filter.dims)
 
     def transpose_filters(self):
         net = self._model
--
GitLab
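
Editor's note: the bug this patch fixes is that when several MatMul ops share
one constant weight (one kernel), the old code transposed that weight in place
once per referencing op, so the second pass undid the first and the tensor
ended up in the wrong layout while every op was still marked transpose_b=1.
The patch tracks already-handled weights in transposed_weights and also moves
arg.i = 1 ahead of the check, so each op's transpose_b attribute is still set
even when the shared tensor data was already flipped by an earlier op. Below
is a minimal standalone sketch of the failure and the fix; it is not MACE
code, and names such as shared_w and matmul_rhs_inputs are hypothetical.

import numpy as np

# Two MatMul ops whose right-hand-side input is the same constant weight.
weights = {"shared_w": np.arange(6, dtype=np.float32).reshape(2, 3)}
matmul_rhs_inputs = ["shared_w", "shared_w"]

# Buggy behavior: transpose once per referencing op. An even number of
# passes restores the original layout, which is wrong.
w = weights["shared_w"].copy()
for _ in matmul_rhs_inputs:
    w = w.transpose(1, 0)
print(w.shape)  # (2, 3) -- back to the un-transposed layout

# Fixed behavior (mirrors the patch): remember which weights were handled
# so each shared tensor is transposed exactly once.
transposed_weights = []
for rhs in matmul_rhs_inputs:
    if rhs not in transposed_weights:
        weights[rhs] = weights[rhs].transpose(1, 0)
        transposed_weights.append(rhs)
print(weights["shared_w"].shape)  # (3, 2) -- transposed exactly once

A set would give O(1) membership tests instead of the list used in the patch,
but with the handful of MatMul weights in a typical model the difference is
negligible.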