diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index ce1e02e915286d06d52a2921aa213cb8c780cffd..1b09b50cdd03ca27e074e2ec2a9bfc330eafba95 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -153,3 +153,4 @@ from .floor_mod import _floor_mod_tbe from .scatter_nd_update import _scatter_nd_update_tbe from .avg_pool import _avg_pool_tbe from .avg_pool_grad import _avg_pool_grad_tbe +from .ones_like import _ones_like_tbe diff --git a/mindspore/ops/_op_impl/tbe/ones_like.py b/mindspore/ops/_op_impl/tbe/ones_like.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6871cac5b357aa54e5b824ec55d96085b6cbdf --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/ones_like.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""OnesLike op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +ones_like_op_info = TBERegOp("OnesLike") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("ones_like.so") \ + .compute_cost(10) \ + .kernel_name("ones_like") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(ones_like_op_info) +def _ones_like_tbe(): + """OnesLike TBE register""" + return