提交 107b956a 编写于 作者: 叶剑武

fix pad converter and support padv2

上级 2dbd1eb7
...@@ -102,6 +102,7 @@ TFSupportedOps = [ ...@@ -102,6 +102,7 @@ TFSupportedOps = [
'DepthToSpace', 'DepthToSpace',
'SpaceToDepth', 'SpaceToDepth',
'Pad', 'Pad',
'PadV2',
'ConcatV2', 'ConcatV2',
'Mean', 'Mean',
'Prod', 'Prod',
...@@ -252,6 +253,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -252,6 +253,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.DepthToSpace.name: self.convert_space_depth, TFOpType.DepthToSpace.name: self.convert_space_depth,
TFOpType.SpaceToDepth.name: self.convert_space_depth, TFOpType.SpaceToDepth.name: self.convert_space_depth,
TFOpType.Pad.name: self.convert_pad, TFOpType.Pad.name: self.convert_pad,
TFOpType.PadV2.name: self.convert_pad,
TFOpType.ConcatV2.name: self.convert_concat, TFOpType.ConcatV2.name: self.convert_concat,
TFOpType.Const.name: self.convert_nop, TFOpType.Const.name: self.convert_nop,
TFOpType.Gather.name: self.convert_gather, TFOpType.Gather.name: self.convert_gather,
...@@ -785,13 +787,19 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -785,13 +787,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
pad_type_arg = op.arg.add() pad_type_arg = op.arg.add()
pad_type_arg.name = MaceKeyword.mace_pad_type_str pad_type_arg.name = MaceKeyword.mace_pad_type_str
if tf_op.type == TFOpType.Pad: if tf_op.type == TFOpType.Pad or tf_op.type == TFOpType.PadV2:
if len(tf_op.inputs) == 3: if len(tf_op.inputs) == 3:
constant_value_arg = op.arg.add() constant_value_arg = op.arg.add()
constant_value_arg.name = MaceKeyword.mace_constant_value_str constant_value_arg.name = MaceKeyword.mace_constant_value_str
constant_value = tf_op.inputs[2].eval().astype(np.int32) \ constant_value = tf_op.inputs[2].eval().flat[0]
.flat[0] tf_dt = tf_op.inputs[2].dtype
if tf_dt == tf.float32:
constant_value_arg.f = constant_value
elif tf_dt == tf.int32:
constant_value_arg.i = constant_value constant_value_arg.i = constant_value
else:
mace_check(False,
"Unsupported data type: %s" % tf_dt.name)
self._skip_tensor.add(tf_op.inputs[2].name) self._skip_tensor.add(tf_op.inputs[2].name)
pad_type_arg.i = PadType.CONSTANT.value pad_type_arg.i = PadType.CONSTANT.value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册