Commit c79d6d13 authored by 李寅

Merge branch 'fix-tensor-reuse' into 'master'

Fix filters being transposed repeatedly when different ops reuse the same tensor.

See merge request !754
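
The fix guards each in-place filter transpose with a transposed_filter set, so a constant tensor shared by several ops is transposed only once. Below is a minimal, self-contained sketch of that deduplication pattern (not MACE code; the consts/ops structures and the tensor name are invented for illustration):

import numpy as np

# Hypothetical model data: two conv ops share one HWIO filter tensor by name.
consts = {"shared_filter":
          np.arange(2 * 2 * 3 * 4, dtype=np.float32).reshape(2, 2, 3, 4)}
ops = [{"type": "Conv2D", "input": ["x", "shared_filter"]},
       {"type": "Conv2D", "input": ["y", "shared_filter"]}]

transposed_filter = set()  # names of filters already converted to OIHW
for op in ops:
    name = op["input"][1]
    if op["type"] == "Conv2D" and name not in transposed_filter:
        consts[name] = consts[name].transpose(3, 2, 0, 1)  # HWIO -> OIHW
        transposed_filter.add(name)  # the second op sees the name and skips it

print(consts["shared_filter"].shape)  # (4, 3, 2, 2): transposed exactly once
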
@@ -983,6 +983,7 @@ class Transformer(base_converter.ConverterInterface):
     def transpose_filters(self):
         net = self._model
         filter_format = self.filter_format()
+        transposed_filter = set()
         if self._option.quantize:
             print("Transpose filters to OHWI")
@@ -995,14 +996,16 @@ class Transformer(base_converter.ConverterInterface):
                            "filter format: %s" % filter_format.name)
             for op in net.op:
-                if op.type == MaceOp.Conv2D.name \
-                        or op.type == MaceOp.Deconv2D.name:
+                if (op.type == MaceOp.Conv2D.name
+                        or op.type == MaceOp.Deconv2D.name) \
+                        and op.input[1] not in transposed_filter:
                     filter = self._consts[op.input[1]]
                     filter_data = np.array(filter.float_data).reshape(
                         filter.dims)
                     filter_data = filter_data.transpose(transpose_order)
                     filter.float_data[:] = filter_data.flat
                     filter.dims[:] = filter_data.shape
+                    transposed_filter.add(op.input[1])
             self.set_filter_format(FilterFormat.OHWI)
         else:
@@ -1010,24 +1013,29 @@ class Transformer(base_converter.ConverterInterface):
             # transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM)
             if filter_format == FilterFormat.HWIO:
                 for op in net.op:
-                    if op.type == MaceOp.Conv2D.name \
-                            or op.type == MaceOp.Deconv2D.name \
-                            or op.type == MaceOp.DepthwiseConv2d.name:
+                    if (op.type == MaceOp.Conv2D.name
+                            or op.type == MaceOp.Deconv2D.name
+                            or op.type == MaceOp.DepthwiseConv2d.name) \
+                            and op.input[1] not in transposed_filter:
                         filter = self._consts[op.input[1]]
                         filter_data = np.array(filter.float_data).reshape(
                             filter.dims)
                         filter_data = filter_data.transpose(3, 2, 0, 1)
                         filter.float_data[:] = filter_data.flat
                         filter.dims[:] = filter_data.shape
+                        transposed_filter.add(op.input[1])
-                    if (op.type == MaceOp.MatMul.name and
-                            ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None):  # noqa
+                    if (op.type == MaceOp.MatMul.name
+                            and (ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None)  # noqa
+                            and op.input[1] not in transposed_filter):
                         filter = self._consts[op.input[0]]
                         filter_data = np.array(filter.float_data).reshape(
                             filter.dims)
                         filter_data = filter_data.transpose(3, 2, 0, 1)
                         filter.float_data[:] = filter_data.flat
                         filter.dims[:] = filter_data.shape
+                        transposed_filter.add(op.input[0])
-                    if op.type == MaceOp.FullyConnected.name:
+                    if op.type == MaceOp.FullyConnected.name \
+                            and op.input[1] not in transposed_filter:
                         weight = self._consts[op.input[1]]
                         if len(weight.dims) == 4:
                             weight_data = np.array(weight.float_data).reshape(
@@ -1035,16 +1043,19 @@ class Transformer(base_converter.ConverterInterface):
                             weight_data = weight_data.transpose(3, 2, 0, 1)
                             weight.float_data[:] = weight_data.flat
                             weight.dims[:] = weight_data.shape
+                            transposed_filter.add(op.input[1])
                 self.set_filter_format(FilterFormat.OIHW)
             for op in net.op:
-                if op.type == MaceOp.Deconv2D.name:
+                if op.type == MaceOp.Deconv2D.name \
+                        and op.input[1] not in transposed_filter:
                     filter = self._consts[op.input[1]]
                     filter_data = np.array(filter.float_data).reshape(
                         filter.dims)
                     filter_data = filter_data.transpose(1, 0, 2, 3)
                     filter.float_data[:] = filter_data.flat
                     filter.dims[:] = filter_data.shape
+                    transposed_filter.add(op.input[1])
             return False