diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py b/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py index 8af0f8d3b4055d774f82ef74bd654659802ad200..c385f7dee06ef98d19cf3dee6675205b83f18b87 100755 --- a/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py @@ -114,6 +114,9 @@ def build_op(build_type, json_str): return get_op_pattern() # call function + if kernel_name[0:19] == "bounding_box_encode": + return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name_val=kernel_name) + return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) except Exception as e: diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index fbe0b1d6638e386290763798c77ec05fe99858d1..9da7ddd71e2194ec61ac4930ab723b4d8d9a8b62 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -84,7 +84,11 @@ static std::map tbe_func_adapter_map = { {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, {"adam", "apply_adam_d"}, {"r_oi_align", "roi_align"}, - {"r_oi_align_grad", "roi_align_grad"}}; + {"r_oi_align_grad", "roi_align_grad"}, + {"i_ou", "iou"}, + {"s_gd", "sgd"}, + {"l_ars_update", "lars_v2_update"}, + {"n_ms_with_mask", "nms_with_mask"}}; void TbeAdapter::NormalizeFuncName(std::string *func_name) { if (func_name == nullptr) { diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc index 35413247970f7943fd4db717a62d60e205dfdf20..161ca39383e4dacd03c48960e5909b4c0a7e9e17 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc @@ -430,6 +430,18 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo attr_value = GetValue>(value); } (*attr_obj)["value"] = attr_value; + } else if (type == "listFloat") { + std::vector attr_value; + auto value_type = value->type(); + MS_EXCEPTION_IF_NULL(value_type); + auto value_type_str = value_type->ToString(); + if (value_type_str == "float") { + float data = GetValue(value); + attr_value.push_back(data); + } else { + attr_value = GetValue>(value); + } + (*attr_obj)["value"] = attr_value; } else if (type == "listListInt") { auto attr_value = GetValue>>(value); (*attr_obj)["value"] = attr_value; diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 4e8e827e5249ed03fe5d12127caa7cd6c9304c32..f47285309dbe2a3b6df86da95cf62370ebabbab0 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -171,3 +171,11 @@ from .resize_bilinear_grad import _resize_bilinear_grad_tbe from .flatten import _flatten_tbe from .roi_align import _roi_align_tbe from .roi_align_grad import _roi_align_grad_tbe +from .bounding_box_decode import _bounding_box_decode_tbe +from .bounding_box_encode import _bounding_box_encode_tbe +from .check_valid import _check_valid_tbe +from .iou import _iou_tbe +from .nms_with_mask import nms_with_mask_op_info +from .random_choice_with_mask import random_choice_with_mask_op_info +from .sgd import sgd_op_info +from .lars_update import lars_update_op_info diff --git a/mindspore/ops/_op_impl/tbe/bounding_box_decode.py b/mindspore/ops/_op_impl/tbe/bounding_box_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f2f5b057f8e2568c69872681c093351a7aa75c --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/bounding_box_decode.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BoundingBoxDecode op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +bounding_box_decode_op_info = TBERegOp("BoundingBoxDecode") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("bounding_box_decode.so") \ + .compute_cost(10) \ + .kernel_name("bounding_box_decode") \ + .partial_flag(True) \ + .attr("means", "optional", "listFloat", "all") \ + .attr("stds", "optional", "listFloat", "all") \ + .attr("max_shape", "optional", "listInt", "all") \ + .attr("wh_ratio_clip", "optional", "float", "all") \ + .input(0, "rois", False, "required", "all") \ + .input(1, "deltas", False, "required", "all") \ + .output(0, "bboxes", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(bounding_box_decode_op_info) +def _bounding_box_decode_tbe(): + """BoundingBoxDecode TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/bounding_box_encode.py b/mindspore/ops/_op_impl/tbe/bounding_box_encode.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c2306b647a6998174fc410fea303c3760169ab --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/bounding_box_encode.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BoundingBoxEncode op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +bounding_box_encode_op_info = TBERegOp("BoundingBoxEncode") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("bounding_box_encode.so") \ + .compute_cost(10) \ + .kernel_name("bounding_box_encode") \ + .partial_flag(True) \ + .attr("means", "optional", "listFloat", "all") \ + .attr("stds", "optional", "listFloat", "all") \ + .input(0, "anchor_box", False, "required", "all") \ + .input(1, "ground_truth_box", False, "required", "all") \ + .output(0, "delats", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .get_op_info() + + +@op_info_register(bounding_box_encode_op_info) +def _bounding_box_encode_tbe(): + """BoundingBoxEncode TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/check_valid.py b/mindspore/ops/_op_impl/tbe/check_valid.py new file mode 100644 index 0000000000000000000000000000000000000000..9c489b64c55cd8ad83d6fae11acd280ef0691251 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/check_valid.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""CheckValid op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +check_valid_op_info = TBERegOp("CheckValid") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("check_valid.so") \ + .compute_cost(10) \ + .kernel_name("check_valid") \ + .partial_flag(True) \ + .input(0, "bbox_tensor", False, "required", "all") \ + .input(1, "img_tas", False, "required", "all") \ + .output(0, "valid_tensor", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I8_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(check_valid_op_info) +def _check_valid_tbe(): + """CheckValid TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/iou.py b/mindspore/ops/_op_impl/tbe/iou.py new file mode 100644 index 0000000000000000000000000000000000000000..950bd8c1474a43e91eddfe32450f5ce068041ffc --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/iou.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Iou op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +iou_op_info = TBERegOp("IOU") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("iou.so") \ + .compute_cost(10) \ + .kernel_name("iou") \ + .partial_flag(True) \ + .attr("mode", "required", "str", "all") \ + .input(0, "bboxes", False, "required", "all") \ + .input(1, "gtboxes", False, "required", "all") \ + .output(0, "overlap", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .get_op_info() + + +@op_info_register(iou_op_info) +def _iou_tbe(): + """Iou TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/lars_update.py b/mindspore/ops/_op_impl/tbe/lars_update.py new file mode 100644 index 0000000000000000000000000000000000000000..1f23b8d8729352520dad3cc2ffb93e28ab47a4e9 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/lars_update.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LarsUpdate op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +lars_update_op_info = TBERegOp("LARSUpdate") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("lars_v2_update.so") \ + .compute_cost(10) \ + .kernel_name("lars_v2_update") \ + .partial_flag(True) \ + .attr("hyperpara", "optional", "float", "all") \ + .attr("epsilon", "optional", "float", "all") \ + .attr("use_clip", "optional", "bool", "all") \ + .input(0, "w", False, "required", "all") \ + .input(1, "g", False, "required", "all") \ + .input(2, "w_square_sum", False, "required", "all") \ + .input(3, "g_square_sum", False, "required", "all") \ + .input(4, "weight_decay", False, "required", "all") \ + .input(5, "learning_rate", False, "required", "all") \ + .output(0, "g_new", False, "required", "all") \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(lars_update_op_info) +def _lars_update_tbe(): + """LarsUpdate TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/nms_with_mask.py b/mindspore/ops/_op_impl/tbe/nms_with_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0b9d1af3db677092f31e944a13416ca6c0d5c57c --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/nms_with_mask.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""NMSWithMask op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +nms_with_mask_op_info = TBERegOp("NMSWithMask") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("nms_with_mask.so") \ + .compute_cost(10) \ + .kernel_name("nms_with_mask") \ + .partial_flag(True) \ + .attr("iou_threshold", "optional", "float", "all") \ + .input(0, "box_scores", False, "required", "all") \ + .output(0, "selected_boxes", False, "required", "all") \ + .output(0, "selected_idx", False, "required", "all") \ + .output(0, "selected_mask", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(nms_with_mask_op_info) +def _nms_with_mask_tbe(): + """NMSWithMask TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/random_choice_with_mask.py b/mindspore/ops/_op_impl/tbe/random_choice_with_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc37e7060e9f94866e8b134c07472e9a0e38651 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/random_choice_with_mask.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomChoiceWithMask op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +random_choice_with_mask_op_info = TBERegOp("RandomChoiceWithMask") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("random_choice_with_mask.so") \ + .compute_cost(10) \ + .kernel_name("random_choice_with_mask") \ + .partial_flag(True) \ + .attr("max_shape", "optional", "listInt", "all") \ + .attr("means", "optional", "listFloat", "all") \ + .attr("stds", "optional", "listFloat", "all") \ + .attr("wh_ratio_clip", "optional", "float", "all") \ + .input(0, "rois", False, "required", "all") \ + .input(1, "deltas", False, "required", "all") \ + .output(0, "bboxes", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(random_choice_with_mask_op_info) +def _random_choice_with_mask_tbe(): + """RandomChoiceWithMask TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/sgd.py b/mindspore/ops/_op_impl/tbe/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..64ecc9272ed2287ea8c60247e801a96dffca32cc --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sgd.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SGD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sgd_op_info = TBERegOp("SGD") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("sgd.so") \ + .compute_cost(10) \ + .kernel_name("sgd") \ + .partial_flag(True) \ + .attr("dampening", "optional", "float", "all") \ + .attr("weight_decay", "optional", "float", "all") \ + .attr("nesterov", "optional", "bool", "all") \ + .input(0, "parameters", False, "required", "all") \ + .input(1, "gradient", False, "required", "all") \ + .input(2, "learning_rate", False, "required", "all") \ + .input(3, "accum", False, "required", "all") \ + .input(4, "momentum", False, "required", "all") \ + .input(5, "stat", False, "required", "all") \ + .output(0, "parameters", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, + DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, + DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, + DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, + DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ + .get_op_info() + + +@op_info_register(sgd_op_info) +def _sgd_tbe(): + """SGD TBE register""" + return diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index fda6b1056ffac9d215cf5355759a306ca0926ce8..fc92bfc8fb56d9a70d643a630d839ab933dddb0b 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -16,6 +16,7 @@ """Operators for math.""" import numpy as np +from ... import context from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind from ..._c_expression import signature_dtype as sig_dtype @@ -1950,12 +1951,16 @@ class NMSWithMask(PrimitiveWithInfer): """Init NMSWithMask""" validator.check_value_type("iou_threshold", iou_threshold, [float], self.name) self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) + self.is_ge = context.get_context("enable_ge") def infer_shape(self, bboxes_shape): cls_name = self.name validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) - validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) + if not self.is_ge: + validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name) + else: + validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) num = bboxes_shape[0] return (bboxes_shape, (num,), (num,)) diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index f2c0fccca943ff6166cefefa1fa80898a90b5e5c..87162e4e6a07f2dbf87fa742755fbb4ea64853db 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -175,10 +175,10 @@ class CheckValid(PrimitiveWithInfer): self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output']) def infer_shape(self, bboxes_shape, metas_shape): - validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, self.name) - validator.check_integer("bboxes_shape[-1]", bboxes_shape[-1], 4, Rel.EQ, self.name) - validator.check_integer("img_metas rank", len(metas_shape), 1, Rel.EQ, self.name) - validator.check_integer("img_metas shape[0]", metas_shape[0], 3, Rel.EQ, self.name) + validator.check("bboxes rank", len(bboxes_shape), "", 2, Rel.EQ, self.name) + validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ, self.name) + validator.check("img_metas rank", len(metas_shape), "", 1, Rel.EQ, self.name) + validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ, self.name) return bboxes_shape[:-1] def infer_dtype(self, bboxes_type, metas_type): diff --git a/tests/mindspore_test_framework/utils/block_util.py b/tests/mindspore_test_framework/utils/block_util.py index 5ea7d0b8a6f4e3f0878e119b77f0c270f93a9574..28a3c62b31a81f44cdc17cbc71726262b7c41374 100644 --- a/tests/mindspore_test_framework/utils/block_util.py +++ b/tests/mindspore_test_framework/utils/block_util.py @@ -188,6 +188,10 @@ class InputOpNet(nn.Cell): x = self.op(x1, x2, x3, x4, self.c1) return x + def construct4_c2(self, x1, x2, x3, x4): + x = self.op(x1, x2, x3, x4, self.c1, self.c2) + return x + def construct4_c4(self, x1, x2, x3, x4): x = self.op(x1, x2, x3, x4, self.c1, self.c2, self.c3, self.c4) return x diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index f399d5ce30022f04ef4cde465ee7c04f2043e3f7..47396b6038a97d51e9bc0097b4652d56f8296a68 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -951,6 +951,17 @@ test_case_nn_ops = [ 'desc_inputs': [[1, 1, 2, 2], [1, 5]], 'desc_bprop': [[1, 1, 2, 2]], 'skip': ['backward']}), + ('LARSUpdate', { + 'block': P.LARSUpdate(1e-05, 0.001, False), + 'desc_const': [0.0, 0.001], + 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], + 'desc_bprop': [3, 3], + 'skip': ['backward']}), + ('SGD', { + 'block': P.SGD(0.0, 0.0, False), + 'desc_inputs': [[3, 3], [3, 3], Tensor(0.001, mstype.float32), [3, 3], Tensor(0.1, mstype.float32), [3, 3]], + 'desc_bprop': [3, 3], + 'skip': ['backward']}), ] test_case_array_ops = [