diff --git a/mindspore/ccsrc/operator/prim_nn.cc b/mindspore/ccsrc/operator/prim_nn.cc index dd00a2cad9ed711500d0d0af0c21dd401eb5cbaa..d90c09256f70cc0f5b29fe58c8bbf51d3713c4bf 100644 --- a/mindspore/ccsrc/operator/prim_nn.cc +++ b/mindspore/ccsrc/operator/prim_nn.cc @@ -407,7 +407,11 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti } // convert to bytes(8 bits) mask, using round up - int bytes_count = (count + 7) / 8; + int n128s = count / 128; + if ((count % 128) != 0) { + n128s++; + } + int bytes_count = n128s * 16; std::vector shape_y{bytes_count}; primitive->set_attr("T", kInt32); diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index 2c15f336d534567accc7f5328c8e1b3a1b66123f..37d008940d91a2a3c5e4943eb0753a6715978e6f 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -23,3 +23,4 @@ from .reshape import _reshape_aicpu from .flatten import _flatten_aicpu from .squeeze import _squeeze_aicpu from .expand_dims import _expand_dims_aicpu +from .random_choice_with_mask import _random_choice_with_mask_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py b/mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..badbea3208e036e64082a0cf4aa8c10a29663316 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomChoiceWithMask op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +random_choice_with_mask_op_info = AiCPURegOp("RandomChoiceWithMask") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .output(1, "mask", "required") \ + .attr("count", "int") \ + .attr("seed", "int") \ + .attr("seed2", "int") \ + .dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW) \ + .get_op_info() + +@op_info_register(random_choice_with_mask_op_info) +def _random_choice_with_mask_aicpu(): + """RandomChoiceWithMask AiCPU register""" + return