提交 1d11cac2 编写于 作者: W w-adamski 提交者: Liangliang He

Added sqrt() to tensorflow converter (#343)

上级 fef913af
...@@ -20,7 +20,7 @@ Operator lists ...@@ -20,7 +20,7 @@ Operator lists
"DEPTHWISE_CONV_2D","Y","Only multiplier = 1 is supported; Fusion is supported." "DEPTHWISE_CONV_2D","Y","Only multiplier = 1 is supported; Fusion is supported."
"DEPTH_TO_SPACE","Y","" "DEPTH_TO_SPACE","Y",""
"DEQUANTIZE","Y","Model quantization will be supported later." "DEQUANTIZE","Y","Model quantization will be supported later."
"ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL" "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/SQRT/EQUAL"
"EMBEDDING_LOOKUP","Y","" "EMBEDDING_LOOKUP","Y",""
"EXPANDDIMS","Y","Only CPU and TensorFlow is supported." "EXPANDDIMS","Y","Only CPU and TensorFlow is supported."
"FILL","Y","Only CPU and TensorFlow is supported." "FILL","Y","Only CPU and TensorFlow is supported."
......
...@@ -113,6 +113,7 @@ TFSupportedOps = [ ...@@ -113,6 +113,7 @@ TFSupportedOps = [
'ArgMax', 'ArgMax',
'Split', 'Split',
'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxVars',
'Sqrt',
] ]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
...@@ -187,6 +188,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -187,6 +188,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF, TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF,
TFOpType.Square.name: EltwiseType.POW, TFOpType.Square.name: EltwiseType.POW,
TFOpType.Rsqrt.name: EltwiseType.POW, TFOpType.Rsqrt.name: EltwiseType.POW,
TFOpType.Sqrt.name: EltwiseType.POW,
TFOpType.Equal.name: EltwiseType.EQUAL, TFOpType.Equal.name: EltwiseType.EQUAL,
} }
...@@ -262,6 +264,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -262,6 +264,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.ArgMax.name: self.convert_argmax, TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.Split.name: self.convert_split, TFOpType.Split.name: self.convert_split,
TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize, TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
TFOpType.Sqrt.name: self.convert_elementwise,
} }
self._option = option self._option = option
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
...@@ -509,6 +512,10 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -509,6 +512,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = -0.5 value_arg.f = -0.5
elif tf_op.type == TFOpType.Sqrt:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = 0.5
if type_arg.i != EltwiseType.NEG.value \ if type_arg.i != EltwiseType.NEG.value \
and type_arg.i != EltwiseType.ABS.value: and type_arg.i != EltwiseType.ABS.value:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册