diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 283d6263238a7046f726931ba059a4d9cd0d37d5..fbe0b1d6638e386290763798c77ec05fe99858d1 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -82,7 +82,9 @@ static std::map tbe_func_adapter_map = { {"batch_to_space", "batch_to_space_d"}, {"resize_bilinear", "resize_bilinear_v2_d"}, {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, - {"adam", "apply_adam_d"}}; + {"adam", "apply_adam_d"}, + {"r_oi_align", "roi_align"}, + {"r_oi_align_grad", "roi_align_grad"}}; void TbeAdapter::NormalizeFuncName(std::string *func_name) { if (func_name == nullptr) { diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 6b26263e72bef55149e1c7b3c0e7ee4cb4739dbf..4e8e827e5249ed03fe5d12127caa7cd6c9304c32 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -169,3 +169,5 @@ from .log1p import _log1p_tbe from .resize_bilinear import _resize_bilinear_tbe from .resize_bilinear_grad import _resize_bilinear_grad_tbe from .flatten import _flatten_tbe +from .roi_align import _roi_align_tbe +from .roi_align_grad import _roi_align_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/roi_align.py b/mindspore/ops/_op_impl/tbe/roi_align.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4eed80ce57e5482e3fd1eec9b92f25e8d04e91 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/roi_align.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ROIAlign op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +roi_align_op_info = TBERegOp("ROIAlign") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("roi_align.so") \ + .compute_cost(10) \ + .kernel_name("roi_align") \ + .partial_flag(True) \ + .attr("spatial_scale", "required", "float", "all") \ + .attr("pooled_height", "required", "int", "all") \ + .attr("pooled_width", "required", "int", "all") \ + .attr("sample_num", "optional", "int", "all", "2") \ + .attr("roi_end_mode", "optional", "0,1", "1") \ + .input(0, "features", False, "required", "all") \ + .input(1, "rois", False, "required", "all") \ + .input(2, "rois_n", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_Default, DataType.I32_Default, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_Default, DataType.I32_Default, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(roi_align_op_info) +def _roi_align_tbe(): + """ROIAlign TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/roi_align_grad.py b/mindspore/ops/_op_impl/tbe/roi_align_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..c69fae3fd74900c096d3b368608f0eb041d09575 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/roi_align_grad.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ROIAlignGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +roi_align_grad_op_info = TBERegOp("ROIAlignGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("roi_align_grad.so") \ + .compute_cost(10) \ + .kernel_name("roi_align_grad") \ + .partial_flag(True) \ + .attr("xdiff_shape", "required", "listInt", "all") \ + .attr("pooled_width", "required", "int", "all") \ + .attr("pooled_height", "required", "int", "all") \ + .attr("spatial_scale", "required", "float", "all") \ + .attr("sample_num", "optional", "int", "all") \ + .input(0, "ydiff", False, "required", "all") \ + .input(1, "rois", False, "required", "all") \ + .input(2, "rois_n", False, "optional", "all") \ + .output(0, "xdiff", False, "required", "all") \ + .dtype_format(DataType.F32_5HD, DataType.F32_Default, DataType.I32_Default, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(roi_align_grad_op_info) +def _roi_align_grad_tbe(): + """ROIAlignGrad TBE register""" + return diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 058df1a0f156a6ff6cc637c3edce20fa5518c327..f399d5ce30022f04ef4cde465ee7c04f2043e3f7 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -942,6 +942,15 @@ test_case_nn_ops = [ 'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32), Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)], 'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)], 'skip': ['backward']}), + ('ROIAlign', { + 'block': P.ROIAlign(7, 7, 0.03125, 2), + 'desc_inputs': [[2, 256, 192, 320], [1024, 5]], + 'desc_bprop': [[7,7]]}), + ('ROIAlignGrad', { + 'block': G.ROIAlignGrad((1, 1, 1, 1), 2, 2, 0.5, 2), + 'desc_inputs': [[1, 1, 2, 2], [1, 5]], + 'desc_bprop': [[1, 1, 2, 2]], + 'skip': ['backward']}), ] test_case_array_ops = [