Commit 95536bcb authored by liutuo

fix transposing of matmul weight when multiple MatMul ops share one kernel

Parent be56d6a3
@@ -1047,6 +1047,7 @@ class Transformer(base_converter.ConverterInterface):
         if self._option.device != DeviceType.CPU.value:
             return False
         net = self._model
+        transposed_weights = []
         for op in net.op:
             if op.type == MaceOp.MatMul.name:  # noqa
                 rhs = op.input[1]
@@ -1058,15 +1059,17 @@
                         arg.name = MaceKeyword.mace_transpose_b_str
                         arg.i = 0
                     if arg.i == 0:
-                        filter = self._consts[rhs]
-                        filter_data = np.array(filter.float_data).reshape(
-                            filter.dims)
-                        filter_data = filter_data.transpose(1, 0)
-                        filter.float_data[:] = filter_data.flat
-                        filter.dims[:] = filter_data.shape
-                        arg.i = 1
-                        six.print_('Transpose matmul weight to shape:',
-                                   filter.dims)
+                        if rhs not in transposed_weights:
+                            filter = self._consts[rhs]
+                            filter_data = np.array(filter.float_data).reshape(
+                                filter.dims)
+                            filter_data = filter_data.transpose(1, 0)
+                            filter.float_data[:] = filter_data.flat
+                            filter.dims[:] = filter_data.shape
+                            transposed_weights.append(rhs)
+                            six.print_('Transpose matmul weight to shape:',
+                                       filter.dims)
+                        arg.i = 1

     def transpose_filters(self):
         net = self._model
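
Why the fix is needed: the weight is transposed in place inside the per-op loop, so when several MatMul ops reference the same constant tensor, each op used to transpose it again; with two sharing ops the second transpose restores the original layout even though transpose_b has already been set to 1 on both. The new transposed_weights list records which constants have already been flipped, so the data is transposed only once while every sharing op still gets transpose_b = 1. Below is a minimal, runnable sketch of the same deduplication pattern; the consts dict and matmul_ops list are hypothetical stand-ins for the converter's state, not MACE APIs.

import numpy as np

# Hypothetical stand-ins for the converter state: a name -> tensor map of
# constants, and two MatMul ops that reference the same weight tensor.
consts = {'w': np.arange(6, dtype=np.float32).reshape(2, 3)}
matmul_ops = [{'rhs': 'w', 'transpose_b': 0},
              {'rhs': 'w', 'transpose_b': 0}]

transposed_weights = []  # names of constants already transposed in place
for op in matmul_ops:
    if op['transpose_b'] == 0:
        if op['rhs'] not in transposed_weights:
            # Physically transpose the shared weight only once.
            consts[op['rhs']] = consts[op['rhs']].transpose(1, 0)
            transposed_weights.append(op['rhs'])
        # Every op that consumes the weight still gets the flag set.
        op['transpose_b'] = 1

assert consts['w'].shape == (3, 2)                      # transposed exactly once
assert all(op['transpose_b'] == 1 for op in matmul_ops)

Without the membership check, the second iteration would transpose w back to shape (2, 3), silently undoing the first transpose while both ops still report transpose_b = 1.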