提交 19c241b1 编写于 作者: 刘琦

Merge branch 'fix_transpose_filter' into 'master'

fix deconv transpose filter

See merge request !760
...@@ -983,6 +983,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -983,6 +983,7 @@ class Transformer(base_converter.ConverterInterface):
net = self._model net = self._model
filter_format = self.filter_format() filter_format = self.filter_format()
transposed_filter = set() transposed_filter = set()
transposed_deconv_filter = set()
if self._option.quantize: if self._option.quantize:
print("Transpose filters to OHWI") print("Transpose filters to OHWI")
...@@ -995,9 +996,9 @@ class Transformer(base_converter.ConverterInterface): ...@@ -995,9 +996,9 @@ class Transformer(base_converter.ConverterInterface):
"filter format: %s" % filter_format.name) "filter format: %s" % filter_format.name)
for op in net.op: for op in net.op:
if (op.type == MaceOp.Conv2D.name if (op.type == MaceOp.Conv2D.name or
or op.type == MaceOp.Deconv2D.name) \ op.type == MaceOp.Deconv2D.name) and\
and op.input[1] not in transposed_filter: op.input[1] not in transposed_filter:
filter = self._consts[op.input[1]] filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape( filter_data = np.array(filter.float_data).reshape(
filter.dims) filter.dims)
...@@ -1005,6 +1006,17 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1005,6 +1006,17 @@ class Transformer(base_converter.ConverterInterface):
filter.float_data[:] = filter_data.flat filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape filter.dims[:] = filter_data.shape
transposed_filter.add(op.input[1]) transposed_filter.add(op.input[1])
# deconv's filter's output channel and input channel is reversed
for op in net.op:
if op.type == MaceOp.Deconv2D.name and \
op.input[1] not in transposed_deconv_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(3, 1, 2, 0)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
transposed_deconv_filter.add(op.input[1])
self.set_filter_format(FilterFormat.OHWI) self.set_filter_format(FilterFormat.OHWI)
else: else:
...@@ -1045,16 +1057,17 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1045,16 +1057,17 @@ class Transformer(base_converter.ConverterInterface):
transposed_filter.add(op.input[1]) transposed_filter.add(op.input[1])
self.set_filter_format(FilterFormat.OIHW) self.set_filter_format(FilterFormat.OIHW)
# deconv's filter's output channel and input channel is reversed
for op in net.op: for op in net.op:
if op.type == MaceOp.Deconv2D.name \ if op.type == MaceOp.Deconv2D.name \
and op.input[1] not in transposed_filter: and op.input[1] not in transposed_deconv_filter:
filter = self._consts[op.input[1]] filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape( filter_data = np.array(filter.float_data).reshape(
filter.dims) filter.dims)
filter_data = filter_data.transpose(1, 0, 2, 3) filter_data = filter_data.transpose(1, 0, 2, 3)
filter.float_data[:] = filter_data.flat filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape filter.dims[:] = filter_data.shape
transposed_filter.add(op.input[1]) transposed_deconv_filter.add(op.input[1])
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册