提交 4ea191e1 编写于 作者: L liuqi

Fix filters being transposed repeatedly when different ops reuse the same tensor.

上级 e62ea9b1
......@@ -983,6 +983,7 @@ class Transformer(base_converter.ConverterInterface):
def transpose_filters(self):
net = self._model
filter_format = self.filter_format()
transposed_filter = set()
if self._option.quantize:
print("Transpose filters to OHWI")
......@@ -995,14 +996,16 @@ class Transformer(base_converter.ConverterInterface):
"filter format: %s" % filter_format.name)
for op in net.op:
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.Deconv2D.name:
if (op.type == MaceOp.Conv2D.name
or op.type == MaceOp.Deconv2D.name) \
and op.input[1] not in transposed_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(transpose_order)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
transposed_filter.add(op.input[1])
self.set_filter_format(FilterFormat.OHWI)
else:
......@@ -1010,24 +1013,29 @@ class Transformer(base_converter.ConverterInterface):
# transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM)
if filter_format == FilterFormat.HWIO:
for op in net.op:
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.Deconv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
if (op.type == MaceOp.Conv2D.name
or op.type == MaceOp.Deconv2D.name
or op.type == MaceOp.DepthwiseConv2d.name) \
and op.input[1] not in transposed_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
if (op.type == MaceOp.MatMul.name and
ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None): # noqa
transposed_filter.add(op.input[1])
if (op.type == MaceOp.MatMul.name
and (ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None) # noqa
and op.input[1] not in transposed_filter):
filter = self._consts[op.input[0]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
if op.type == MaceOp.FullyConnected.name:
transposed_filter.add(op.input[0])
if op.type == MaceOp.FullyConnected.name \
and op.input[1] not in transposed_filter:
weight = self._consts[op.input[1]]
if len(weight.dims) == 4:
weight_data = np.array(weight.float_data).reshape(
......@@ -1035,16 +1043,19 @@ class Transformer(base_converter.ConverterInterface):
weight_data = weight_data.transpose(3, 2, 0, 1)
weight.float_data[:] = weight_data.flat
weight.dims[:] = weight_data.shape
transposed_filter.add(op.input[1])
self.set_filter_format(FilterFormat.OIHW)
for op in net.op:
if op.type == MaceOp.Deconv2D.name:
if op.type == MaceOp.Deconv2D.name \
and op.input[1] not in transposed_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(1, 0, 2, 3)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
transposed_filter.add(op.input[1])
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册