From 1d11cac2a95bf336a9fb95c0780a09eaf385547d Mon Sep 17 00:00:00 2001 From: w-adamski <43744870+w-adamski@users.noreply.github.com> Date: Mon, 11 Feb 2019 02:48:09 +0100 Subject: [PATCH] Added sqrt() to tensorflow converter (#343) --- docs/user_guide/op_lists.rst | 2 +- mace/python/tools/converter_tool/tensorflow_converter.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/user_guide/op_lists.rst b/docs/user_guide/op_lists.rst index 6c0b4246..c8dd1448 100644 --- a/docs/user_guide/op_lists.rst +++ b/docs/user_guide/op_lists.rst @@ -20,7 +20,7 @@ Operator lists "DEPTHWISE_CONV_2D","Y","Only multiplier = 1 is supported; Fusion is supported." "DEPTH_TO_SPACE","Y","" "DEQUANTIZE","Y","Model quantization will be supported later." - "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL" + "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/SQRT/EQUAL" "EMBEDDING_LOOKUP","Y","" "EXPANDDIMS","Y","Only CPU and TensorFlow is supported." "FILL","Y","Only CPU and TensorFlow is supported." diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index a45a84e0..fd07247c 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -113,6 +113,7 @@ TFSupportedOps = [ 'ArgMax', 'Split', 'FakeQuantWithMinMaxVars', + 'Sqrt', ] TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) @@ -187,6 +188,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF, TFOpType.Square.name: EltwiseType.POW, TFOpType.Rsqrt.name: EltwiseType.POW, + TFOpType.Sqrt.name: EltwiseType.POW, TFOpType.Equal.name: EltwiseType.EQUAL, } @@ -262,6 +264,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.ArgMax.name: self.convert_argmax, TFOpType.Split.name: self.convert_split, TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize, + TFOpType.Sqrt.name: self.convert_elementwise, } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -509,6 +512,10 @@ class TensorflowConverter(base_converter.ConverterInterface): value_arg = op.arg.add() value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.f = -0.5 + elif tf_op.type == TFOpType.Sqrt: + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = 0.5 if type_arg.i != EltwiseType.NEG.value \ and type_arg.i != EltwiseType.ABS.value: -- GitLab