diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index a1776b2232583a660b304d5d8245d8e1967bc82c..e297590abeedc5f11ad63e9b87065daf17daf7cd 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -700,16 +700,17 @@ class Transformer(base_converter.ConverterInterface): return False def reshape_fc_weight(self): + print("Reshape fully connecrted weight shape") net = self._model for op in net.op: if op.type == MaceOp.FullyConnected.name: weight = self._consts[op.input[1]] - # NCHW - input_shape = list(self._producer[op.input[0]] - .output_shape[0].dims) - weight_shape = [weight.dims[0]] + input_shape[1:] - del weight.dims[:] - weight.dims.extend(weight_shape) + input_op = self._producer[op.input[0]] + input_shape = list(input_op.output_shape[0].dims) + input_data_format = ConverterUtil.data_format(input_op) + weight.dims[:] = [weight.dims[0]] + input_shape[1:] + if input_data_format == DataFormat.NHWC: + self.transpose_shape(weight.dims, [0, 3, 1, 2]) return False