diff --git a/docs/user_guide/op_lists.rst b/docs/user_guide/op_lists.rst index 6c0b4246cd801a7a2c069c625cf8c980bdd7a1a6..c8dd1448c244a86de20aa266c1040f0de4f29848 100644 --- a/docs/user_guide/op_lists.rst +++ b/docs/user_guide/op_lists.rst @@ -20,7 +20,7 @@ Operator lists "DEPTHWISE_CONV_2D","Y","Only multiplier = 1 is supported; Fusion is supported." "DEPTH_TO_SPACE","Y","" "DEQUANTIZE","Y","Model quantization will be supported later." - "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL" + "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/SQRT/EQUAL" "EMBEDDING_LOOKUP","Y","" "EXPANDDIMS","Y","Only CPU and TensorFlow is supported." "FILL","Y","Only CPU and TensorFlow is supported." diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index a45a84e0ea2b9ed8710ecbe9af33109b03766469..fd07247c095891ab4b6acba3913857be2f46e38d 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -113,6 +113,7 @@ TFSupportedOps = [ 'ArgMax', 'Split', 'FakeQuantWithMinMaxVars', + 'Sqrt', ] TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) @@ -187,6 +188,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF, TFOpType.Square.name: EltwiseType.POW, TFOpType.Rsqrt.name: EltwiseType.POW, + TFOpType.Sqrt.name: EltwiseType.POW, TFOpType.Equal.name: EltwiseType.EQUAL, } @@ -262,6 +264,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.ArgMax.name: self.convert_argmax, TFOpType.Split.name: self.convert_split, TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize, + TFOpType.Sqrt.name: self.convert_elementwise, } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -509,6 +512,10 @@ class TensorflowConverter(base_converter.ConverterInterface): value_arg = op.arg.add() value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.f = -0.5 + elif tf_op.type == TFOpType.Sqrt: + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = 0.5 if type_arg.i != EltwiseType.NEG.value \ and type_arg.i != EltwiseType.ABS.value: