提交 7b5a9af2 编写于 作者: Y yejianwu

fix gpu squeeze

上级 4b40d95c
......@@ -196,6 +196,7 @@ MaceTransposableDataFormatOps = [MaceOp.Activation,
MaceOp.Reduce,
MaceOp.Softmax,
MaceOp.Split,
MaceOp.Squeeze,
MaceOp.SqrDiffMean]
......
......@@ -1431,6 +1431,18 @@ class Transformer(base_converter.ConverterInterface):
if axis_arg.i == 1:
axis_arg.i = 3
elif op.type == MaceOp.Squeeze.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if (src_data_format == DataFormat.NCHW
and has_data_format
and len(self._producer[op.input[0]].output_shape[0].dims) == 4 # noqa
and len(op.output_shape[0].dims) == 2
and arg.ints == [2, 3]):
print("Transpose squeeze args: %s(%s)"
% (op.name, op.type))
arg.ints[:] = [1, 2]
elif op.type == MaceOp.Reduce.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册