提交 be144525 编写于 作者: 卢旭辉

Merge branch 'transpose' into 'master'

Fix squeeze if not transposable

See merge request deep-computing/mace!1269
......@@ -1411,24 +1411,37 @@ class Transformer(base_converter.ConverterInterface):
return False
def is_transposable_data_format_ops(self, op):
transposable = op.type in MaceTransposableDataFormatOps
if op.type == MaceOp.Reshape:
input_op = self._producer[op.input[0]]
input_dims = input_op.output_shape[0].dims
output_dims = op.output_shape[0].dims
tranposable = True
if len(input_op.output_shape) != 1 or \
len(input_dims) != 4 or len(output_dims) != 4:
tranposable = False
transposable = False
else:
in_b, in_h, in_w, in_c = self.sort_feature_map_shape(
input_dims, ConverterUtil.data_format(input_op))
ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape(
output_dims, ConverterUtil.data_format(op))
tranposable = (in_b == ou_b and in_c == ou_c)
if not tranposable:
print("In this model, reshape is not transposable op.")
return tranposable
return op.type in MaceTransposableDataFormatOps
transposable = (in_b == ou_b and in_c == ou_c)
elif op.type == MaceOp.Squeeze:
input_dims = self._producer[op.input[0]].output_shape[0].dims
output_dims = op.output_shape[0].dims
src_df = ConverterUtil.data_format(self._model)
arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
if len(input_dims) == 4 and len(output_dims) == 2 and \
((src_df == DataFormat.NCHW and arg.ints == [2, 3]) or
(src_df == DataFormat.NHWC and arg.ints == [1, 2])):
transposable = True
else:
transposable = False
if op.type in MaceTransposableDataFormatOps and not transposable:
print("%s(%s) is not a transposable op in this model."
% (op.name, op.type))
return transposable
def update_data_format(self):
print("update data format")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册