提交 be56d6a3 编写于 作者: 刘托

Merge branch 'fix_gpu_squeeze' into 'master'

fix gpu squeeze

See merge request !1129
...@@ -196,6 +196,7 @@ MaceTransposableDataFormatOps = [MaceOp.Activation, ...@@ -196,6 +196,7 @@ MaceTransposableDataFormatOps = [MaceOp.Activation,
MaceOp.Reduce, MaceOp.Reduce,
MaceOp.Softmax, MaceOp.Softmax,
MaceOp.Split, MaceOp.Split,
MaceOp.Squeeze,
MaceOp.SqrDiffMean] MaceOp.SqrDiffMean]
......
...@@ -1431,6 +1431,18 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1431,6 +1431,18 @@ class Transformer(base_converter.ConverterInterface):
if axis_arg.i == 1: if axis_arg.i == 1:
axis_arg.i = 3 axis_arg.i = 3
elif op.type == MaceOp.Squeeze.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if (src_data_format == DataFormat.NCHW
and has_data_format
and len(self._producer[op.input[0]].output_shape[0].dims) == 4 # noqa
and len(op.output_shape[0].dims) == 2
and arg.ints == [2, 3]):
print("Transpose squeeze args: %s(%s)"
% (op.name, op.type))
arg.ints[:] = [1, 2]
elif op.type == MaceOp.Reduce.name: elif op.type == MaceOp.Reduce.name:
for arg in op.arg: for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str: if arg.name == MaceKeyword.mace_axis_str:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册