diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index b30eda03ea2fbf9998e3877f527bb3e845268bce..99aff51b5698b8b0f7f13995012f9f6ee07f90e6 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -189,3 +189,5 @@ from .pack import _pack_tbe from .unpack import _unpack_tbe from .prelu import _prelu_tbe from .prelu_grad import _prelu_grad_tbe +from .binary_cross_entropy import _binary_cross_entropy_tbe +from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/binary_cross_entropy.py b/mindspore/ops/_op_impl/tbe/binary_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..bbb4dcab0bdb6e6ee9660c39e37d3e212f4469d1 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/binary_cross_entropy.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BinaryCrossEntropy op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +binary_cross_entropy_op_info = TBERegOp("BinaryCrossEntropy") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("binary_cross_entropy.so") \ + .compute_cost(10) \ + .kernel_name("binary_cross_entropy") \ + .partial_flag(True) \ + .attr("reduction", "optional", "str", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "y", False, "required", "all") \ + .input(2, "weight", False, "optional", "all") \ + .output(0, "output", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(binary_cross_entropy_op_info) +def _binary_cross_entropy_tbe(): + """BinaryCrossEntropy TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py b/mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..9813d448055098939c66ab31aa0ea66f93975c09 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BinaryCrossEntropyGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +binary_cross_entropy_grad_op_info = TBERegOp("BinaryCrossEntropyGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("binary_cross_entropy_grad.so") \ + .compute_cost(10) \ + .kernel_name("binary_cross_entropy_grad") \ + .partial_flag(True) \ + .attr("reduction", "optional", "str", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "y", False, "required", "all") \ + .input(2, "grad_output", False, "required", "all") \ + .input(3, "weight", False, "optional", "all") \ + .output(0, "output", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(binary_cross_entropy_grad_op_info) +def _binary_cross_entropy_grad_tbe(): + """BinaryCrossEntropyGrad TBE register""" + return diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 3509fb1d02dd3245468a39236c25d1ea2d5a0d0b..de884259c9c8fc3da1da461ec32b397460f0f79d 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -972,6 +972,19 @@ test_case_nn_ops = [ 'desc_inputs': [[3, 3], [3, 3], Tensor(0.001, mstype.float32), [3, 3], Tensor(0.1, mstype.float32), [3, 3]], 'desc_bprop': [3, 3], 'skip': ['backward']}), + ('BinaryCrossEntropy', { + 'block': P.BinaryCrossEntropy(), + 'desc_inputs': [Tensor([[0.3, 0.8], [0.4, 0.3]], mstype.float16), + Tensor([[0.4, 1.2], [-0.4, -0.9]], mstype.float16), + Tensor([[-1.4, -0.7], [0.9, 0.7]], mstype.float16)], + 'desc_bprop': []}), + ('BinaryCrossEntropyGrad', { + 'block': G.BinaryCrossEntropyGrad(), + 'desc_inputs': [Tensor([[0.3, 0.8], [0.4, 0.3]], mstype.float16), + Tensor([[0.4, 1.2], [-0.4, -0.9]], mstype.float16), Tensor(0.85, mstype.float16), + Tensor([[-1.4, -0.7], [0.9, 0.7]], mstype.float16)], + 'desc_bprop': [], + 'skip': ['backward']}), ] test_case_array_ops = [