提交 e3085ec6 编写于 作者: L liuqi

Fix fc weight shape transformer bug.

上级 3b11a1b2
......@@ -700,16 +700,17 @@ class Transformer(base_converter.ConverterInterface):
return False
def reshape_fc_weight(self):
print("Reshape fully connecrted weight shape")
net = self._model
for op in net.op:
if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]]
# NCHW
input_shape = list(self._producer[op.input[0]]
.output_shape[0].dims)
weight_shape = [weight.dims[0]] + input_shape[1:]
del weight.dims[:]
weight.dims.extend(weight_shape)
input_op = self._producer[op.input[0]]
input_shape = list(input_op.output_shape[0].dims)
input_data_format = ConverterUtil.data_format(input_op)
weight.dims[:] = [weight.dims[0]] + input_shape[1:]
if input_data_format == DataFormat.NHWC:
self.transpose_shape(weight.dims, [0, 3, 1, 2])
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册