From 1837e2f79bef85dc22e2b752fcb6ae846507b6bd Mon Sep 17 00:00:00 2001 From: liutuo Date: Thu, 23 Aug 2018 11:53:59 +0800 Subject: [PATCH] fix deconv transpose filter --- .../tools/converter_tool/transformer.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 2f7bd84a..803a61ce 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -983,6 +983,7 @@ class Transformer(base_converter.ConverterInterface): net = self._model filter_format = self.filter_format() transposed_filter = set() + transposed_deconv_filter = set() if self._option.quantize: print("Transpose filters to OHWI") @@ -995,9 +996,9 @@ class Transformer(base_converter.ConverterInterface): "filter format: %s" % filter_format.name) for op in net.op: - if (op.type == MaceOp.Conv2D.name - or op.type == MaceOp.Deconv2D.name) \ - and op.input[1] not in transposed_filter: + if (op.type == MaceOp.Conv2D.name or + op.type == MaceOp.Deconv2D.name) and\ + op.input[1] not in transposed_filter: filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) @@ -1005,6 +1006,17 @@ class Transformer(base_converter.ConverterInterface): filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape transposed_filter.add(op.input[1]) + # deconv's filter's output channel and input channel is reversed + for op in net.op: + if op.type == MaceOp.Deconv2D.name and \ + op.input[1] not in transposed_deconv_filter: + filter = self._consts[op.input[1]] + filter_data = np.array(filter.float_data).reshape( + filter.dims) + filter_data = filter_data.transpose(3, 1, 2, 0) + filter.float_data[:] = filter_data.flat + filter.dims[:] = filter_data.shape + transposed_deconv_filter.add(op.input[1]) self.set_filter_format(FilterFormat.OHWI) else: @@ -1045,16 +1057,17 @@ class Transformer(base_converter.ConverterInterface): transposed_filter.add(op.input[1]) self.set_filter_format(FilterFormat.OIHW) + # deconv's filter's output channel and input channel is reversed for op in net.op: if op.type == MaceOp.Deconv2D.name \ - and op.input[1] not in transposed_filter: + and op.input[1] not in transposed_deconv_filter: filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) filter_data = filter_data.transpose(1, 0, 2, 3) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape - transposed_filter.add(op.input[1]) + transposed_deconv_filter.add(op.input[1]) return False -- GitLab