From 70c80c05951b37c74837e07e948da80ea41aee12 Mon Sep 17 00:00:00 2001
From: buxue
Date: Mon, 27 Apr 2020 17:41:15 +0800
Subject: [PATCH] dock FloorMod GreaterEqual NotEqual ScatterNdUpdate

---
 mindspore/ops/_op_impl/tbe/__init__.py      |  6 ++-
 .../ops/_op_impl/tbe/{fill_d.py => fill.py} |  4 +-
 mindspore/ops/_op_impl/tbe/floor_mod.py     | 38 ++++++++++++++++
 mindspore/ops/_op_impl/tbe/greater_equal.py | 45 +++++++++++++++++++
 mindspore/ops/_op_impl/tbe/not_equal.py     | 45 +++++++++++++++++++
 mindspore/ops/_op_impl/tbe/scatter_nd.py    |  2 +-
 .../ops/_op_impl/tbe/scatter_nd_update.py   | 42 +++++++++++++++++
 7 files changed, 178 insertions(+), 4 deletions(-)
 rename mindspore/ops/_op_impl/tbe/{fill_d.py => fill.py} (97%)
 create mode 100644 mindspore/ops/_op_impl/tbe/floor_mod.py
 create mode 100644 mindspore/ops/_op_impl/tbe/greater_equal.py
 create mode 100644 mindspore/ops/_op_impl/tbe/not_equal.py
 create mode 100644 mindspore/ops/_op_impl/tbe/scatter_nd_update.py

diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 8030aac5c..f9240ee32 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -142,8 +142,12 @@ from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
 from .fused_mul_add import _fused_mul_add_tbe
 from .fused_mul_add_n import _fused_mul_add_n_tbe
 from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
-from .fill_d import _fill_d_op_tbe
+from .fill import _fill_op_tbe
 from .erf import _erf_op_tbe
 from .depthwise_conv2d import _depthwise_conv2d_tbe
 from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe
 from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe
+from .greater_equal import _greater_equal_tbe
+from .not_equal import _not_equal_tbe
+from .floor_mod import _floor_mod_tbe
+from .scatter_nd_update import _scatter_nd_update_tbe
diff --git a/mindspore/ops/_op_impl/tbe/fill_d.py b/mindspore/ops/_op_impl/tbe/fill.py
similarity index 97%
rename from mindspore/ops/_op_impl/tbe/fill_d.py
rename to mindspore/ops/_op_impl/tbe/fill.py
index 97c6b73cf..90301f123 100644
--- a/mindspore/ops/_op_impl/tbe/fill_d.py
+++ b/mindspore/ops/_op_impl/tbe/fill.py
@@ -16,7 +16,7 @@
 """FillD op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

-fill_d_op_info = TBERegOp("FillD") \
+fill_d_op_info = TBERegOp("Fill") \
     .fusion_type("ELEMWISE") \
     .async_flag(False) \
     .binfile_name("fill_d.so") \
@@ -50,6 +50,6 @@ fill_d_op_info = TBERegOp("FillD") \


 @op_info_register(fill_d_op_info)
-def _fill_d_op_tbe():
+def _fill_op_tbe():
     """FillD TBE register"""
     return
diff --git a/mindspore/ops/_op_impl/tbe/floor_mod.py b/mindspore/ops/_op_impl/tbe/floor_mod.py
new file mode 100644
index 000000000..031f160e0
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/floor_mod.py
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""FloorMod op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+floor_mod_op_info = TBERegOp("FloorMod") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("floor_mod.so") \
+    .compute_cost(10) \
+    .kernel_name("floor_mod") \
+    .partial_flag(True) \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .get_op_info()
+
+
+@op_info_register(floor_mod_op_info)
+def _floor_mod_tbe():
+    """FloorMod TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/greater_equal.py b/mindspore/ops/_op_impl/tbe/greater_equal.py
new file mode 100644
index 000000000..5609f15f1
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/greater_equal.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""GreaterEqual op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+greater_equal_op_info = TBERegOp("GreaterEqual") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("greater_equal.so") \
+    .compute_cost(10) \
+    .kernel_name("greater_equal") \
+    .partial_flag(True) \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
+    .get_op_info()
+
+
+@op_info_register(greater_equal_op_info)
+def _greater_equal_tbe():
+    """GreaterEqual TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/not_equal.py b/mindspore/ops/_op_impl/tbe/not_equal.py
new file mode 100644
index 000000000..bd801d9a4
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/not_equal.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""NotEqual op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+not_equal_op_info = TBERegOp("NotEqual") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("not_equal.so") \
+    .compute_cost(10) \
+    .kernel_name("not_equal") \
+    .partial_flag(True) \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
+    .get_op_info()
+
+
+@op_info_register(not_equal_op_info)
+def _not_equal_tbe():
+    """NotEqual TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/scatter_nd.py b/mindspore/ops/_op_impl/tbe/scatter_nd.py
index 6c9eae3ad..168b34582 100644
--- a/mindspore/ops/_op_impl/tbe/scatter_nd.py
+++ b/mindspore/ops/_op_impl/tbe/scatter_nd.py
@@ -37,5 +37,5 @@ scatter_nd_op_info = TBERegOp("ScatterNd") \

 @op_info_register(scatter_nd_op_info)
 def _scatter_nd_tbe():
-    """Conv2D TBE register"""
+    """ScatterNd TBE register"""
     return
diff --git a/mindspore/ops/_op_impl/tbe/scatter_nd_update.py b/mindspore/ops/_op_impl/tbe/scatter_nd_update.py
new file mode 100644
index 000000000..df0996f26
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/scatter_nd_update.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ScatterNdUpdate op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("scatter_nd_update.so") \
+    .compute_cost(10) \
+    .kernel_name("scatter_nd_update") \
+    .partial_flag(True) \
+    .attr("use_locking", "optional", "bool", "all") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "indices", False, "required", "all") \
+    .input(2, "updates", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
+    .get_op_info()
+
+
+@op_info_register(scatter_nd_update_op_info)
+def _scatter_nd_update_tbe():
+    """ScatterNdUpdate TBE register"""
+    return
-- 
GitLab
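
Reviewer note (not part of the patch): the sketch below is a minimal smoke test for the four ops whose TBE registrations this change adds. It assumes a MindSpore build with the Ascend/TBE backend and uses the existing front-end primitives P.FloorMod, P.GreaterEqual, P.NotEqual and P.ScatterNdUpdate; the SmokeNet class, shapes and values are made up for illustration only.

import numpy as np
import mindspore.context as context
from mindspore import Tensor, Parameter, nn
from mindspore.ops import operations as P

# Run on the Ascend backend so the TBE kernels registered above are selected.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class SmokeNet(nn.Cell):
    """Exercises FloorMod, GreaterEqual, NotEqual and ScatterNdUpdate."""
    def __init__(self):
        super(SmokeNet, self).__init__()
        self.floor_mod = P.FloorMod()
        self.greater_equal = P.GreaterEqual()
        self.not_equal = P.NotEqual()
        self.scatter_nd_update = P.ScatterNdUpdate()
        # ScatterNdUpdate writes into a Parameter in place.
        self.var = Parameter(Tensor(np.zeros((4, 3), np.float32)), name="var")

    def construct(self, x1, x2, indices, updates):
        mod = self.floor_mod(x1, x2)
        ge = self.greater_equal(x1, x2)
        ne = self.not_equal(x1, x2)
        updated = self.scatter_nd_update(self.var, indices, updates)
        return mod, ge, ne, updated


net = SmokeNet()
x1 = Tensor(np.array([5.0, -7.0, 9.0], np.float32))
x2 = Tensor(np.array([2.0, 3.0, -4.0], np.float32))
# Replace rows 0 and 2 of the (4, 3) variable with ones.
indices = Tensor(np.array([[0], [2]], np.int32))
updates = Tensor(np.ones((2, 3), np.float32))
print(net(x1, x2, indices, updates))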