提交 0030e3bc 编写于 作者: 刘琦

Merge branch 'fix_multi_inputs_converter' into 'master'

Fix pad and conv dilation issues

See merge request !527
......@@ -294,7 +294,12 @@ class TensorflowConverter(base_converter.ConverterInterface):
if op.type != MaceOp.Deconv2D.name:
dilation_arg = op.arg.add()
dilation_arg.name = MaceKeyword.mace_dilations_str
dilation_arg.ints.extend(tf_op.get_attr(tf_dilations_str)[1:3])
try:
dilation_val = tf_op.get_attr(tf_dilations_str)[1:3]
except ValueError:
dilation_val = [1, 1]
dilation_arg.ints.extend(dilation_val)
def convert_elementwise(self, tf_op):
op = self.convert_general_op(tf_op)
......
......@@ -713,18 +713,21 @@ class Transformer(base_converter.ConverterInterface):
# transpose args
if op.type == MaceOp.Pad.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_paddings_str and len(
arg.ints) == 4:
if arg.name == MaceKeyword.mace_paddings_str:
mace_check(len(arg.ints) == 8,
"pad dim rank should be 8.")
if ConverterUtil.data_format(op) == DataFormat.NHWC \
and self._target_data_format == DataFormat.NCHW: # noqa
print("Transpose pad args: %s(%s)"
% (op.name, op.type))
self.transpose_shape(arg.ints, [0, 3, 1, 2])
self.transpose_shape(arg.ints,
[0, 1, 6, 7, 2, 3, 4, 5])
elif ConverterUtil.data_format(op) == DataFormat.NCHW \
and self._target_data_format == DataFormat.NHWC: # noqa
print("Transpose pad args: %s(%s)"
% (op.name, op.type))
self.transpose_shape(arg.ints, [0, 2, 3, 1])
self.transpose_shape(arg.ints,
[0, 1, 4, 5, 6, 7, 2, 3])
elif op.type == MaceOp.Concat.name or op.type == MaceOp.Slice.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册