提交 1837e2f7 编写于 作者: L liutuo

fix deconv transpose filter

上级 32be5ec8
......@@ -983,6 +983,7 @@ class Transformer(base_converter.ConverterInterface):
net = self._model
filter_format = self.filter_format()
transposed_filter = set()
transposed_deconv_filter = set()
if self._option.quantize:
print("Transpose filters to OHWI")
......@@ -995,9 +996,9 @@ class Transformer(base_converter.ConverterInterface):
"filter format: %s" % filter_format.name)
for op in net.op:
if (op.type == MaceOp.Conv2D.name
or op.type == MaceOp.Deconv2D.name) \
and op.input[1] not in transposed_filter:
if (op.type == MaceOp.Conv2D.name or
op.type == MaceOp.Deconv2D.name) and\
op.input[1] not in transposed_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
......@@ -1005,6 +1006,17 @@ class Transformer(base_converter.ConverterInterface):
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
transposed_filter.add(op.input[1])
# deconv's filter's output channel and input channel is reversed
for op in net.op:
if op.type == MaceOp.Deconv2D.name and \
op.input[1] not in transposed_deconv_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(3, 1, 2, 0)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
transposed_deconv_filter.add(op.input[1])
self.set_filter_format(FilterFormat.OHWI)
else:
......@@ -1045,16 +1057,17 @@ class Transformer(base_converter.ConverterInterface):
transposed_filter.add(op.input[1])
self.set_filter_format(FilterFormat.OIHW)
# deconv's filter's output channel and input channel is reversed
for op in net.op:
if op.type == MaceOp.Deconv2D.name \
and op.input[1] not in transposed_filter:
and op.input[1] not in transposed_deconv_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(1, 0, 2, 3)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
transposed_filter.add(op.input[1])
transposed_deconv_filter.add(op.input[1])
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册