diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 2f7bd84acb98320386c03c0d4b07e625c3978a44..803a61ce6ffa37d2b4d5b8b67ecb77c89bc459db 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -983,6 +983,7 @@ class Transformer(base_converter.ConverterInterface): net = self._model filter_format = self.filter_format() transposed_filter = set() + transposed_deconv_filter = set() if self._option.quantize: print("Transpose filters to OHWI") @@ -995,9 +996,9 @@ class Transformer(base_converter.ConverterInterface): "filter format: %s" % filter_format.name) for op in net.op: - if (op.type == MaceOp.Conv2D.name - or op.type == MaceOp.Deconv2D.name) \ - and op.input[1] not in transposed_filter: + if (op.type == MaceOp.Conv2D.name or + op.type == MaceOp.Deconv2D.name) and\ + op.input[1] not in transposed_filter: filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) @@ -1005,6 +1006,17 @@ class Transformer(base_converter.ConverterInterface): filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape transposed_filter.add(op.input[1]) + # deconv's filter's output channel and input channel is reversed + for op in net.op: + if op.type == MaceOp.Deconv2D.name and \ + op.input[1] not in transposed_deconv_filter: + filter = self._consts[op.input[1]] + filter_data = np.array(filter.float_data).reshape( + filter.dims) + filter_data = filter_data.transpose(3, 1, 2, 0) + filter.float_data[:] = filter_data.flat + filter.dims[:] = filter_data.shape + transposed_deconv_filter.add(op.input[1]) self.set_filter_format(FilterFormat.OHWI) else: @@ -1045,16 +1057,17 @@ class Transformer(base_converter.ConverterInterface): transposed_filter.add(op.input[1]) self.set_filter_format(FilterFormat.OIHW) + # deconv's filter's output channel and input channel is reversed for op in net.op: if op.type == MaceOp.Deconv2D.name \ - and op.input[1] not in transposed_filter: + and op.input[1] not in transposed_deconv_filter: filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) filter_data = filter_data.transpose(1, 0, 2, 3) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape - transposed_filter.add(op.input[1]) + transposed_deconv_filter.add(op.input[1]) return False