From 4ea191e1a21866e6c244392a3fd68725af860b80 Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 22 Aug 2018 14:33:27 +0800 Subject: [PATCH] Fix transpose filter repeatly when different ops reuse same tensor. --- .../tools/converter_tool/transformer.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 5f179d65..7908c315 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -983,6 +983,7 @@ class Transformer(base_converter.ConverterInterface): def transpose_filters(self): net = self._model filter_format = self.filter_format() + transposed_filter = set() if self._option.quantize: print("Transpose filters to OHWI") @@ -995,14 +996,16 @@ class Transformer(base_converter.ConverterInterface): "filter format: %s" % filter_format.name) for op in net.op: - if op.type == MaceOp.Conv2D.name \ - or op.type == MaceOp.Deconv2D.name: + if (op.type == MaceOp.Conv2D.name + or op.type == MaceOp.Deconv2D.name) \ + and op.input[1] not in transposed_filter: filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) filter_data = filter_data.transpose(transpose_order) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape + transposed_filter.add(op.input[1]) self.set_filter_format(FilterFormat.OHWI) else: @@ -1010,24 +1013,29 @@ class Transformer(base_converter.ConverterInterface): # transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM) if filter_format == FilterFormat.HWIO: for op in net.op: - if op.type == MaceOp.Conv2D.name \ - or op.type == MaceOp.Deconv2D.name \ - or op.type == MaceOp.DepthwiseConv2d.name: + if (op.type == MaceOp.Conv2D.name + or op.type == MaceOp.Deconv2D.name + or op.type == MaceOp.DepthwiseConv2d.name) \ + and op.input[1] not in transposed_filter: filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) filter_data = filter_data.transpose(3, 2, 0, 1) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape - if (op.type == MaceOp.MatMul.name and - ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None): # noqa + transposed_filter.add(op.input[1]) + if (op.type == MaceOp.MatMul.name + and (ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None) # noqa + and op.input[1] not in transposed_filter): filter = self._consts[op.input[0]] filter_data = np.array(filter.float_data).reshape( filter.dims) filter_data = filter_data.transpose(3, 2, 0, 1) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape - if op.type == MaceOp.FullyConnected.name: + transposed_filter.add(op.input[0]) + if op.type == MaceOp.FullyConnected.name \ + and op.input[1] not in transposed_filter: weight = self._consts[op.input[1]] if len(weight.dims) == 4: weight_data = np.array(weight.float_data).reshape( @@ -1035,16 +1043,19 @@ class Transformer(base_converter.ConverterInterface): weight_data = weight_data.transpose(3, 2, 0, 1) weight.float_data[:] = weight_data.flat weight.dims[:] = weight_data.shape + transposed_filter.add(op.input[1]) self.set_filter_format(FilterFormat.OIHW) for op in net.op: - if op.type == MaceOp.Deconv2D.name: + if op.type == MaceOp.Deconv2D.name \ + and op.input[1] not in transposed_filter: filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) filter_data = filter_data.transpose(1, 0, 2, 3) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape + transposed_filter.add(op.input[1]) return False -- GitLab