diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.cc b/mindspore/ccsrc/kernel/hccl/hcom_util.cc index 61a4d43eb5b92f817075373124a65b10a5ad8ecd..088dbe59d5b69aee64f7995c6a47167d6f5964d5 100644 --- a/mindspore/ccsrc/kernel/hccl/hcom_util.cc +++ b/mindspore/ccsrc/kernel/hccl/hcom_util.cc @@ -67,18 +67,17 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector &shape, size_t *size) { MS_EXCEPTION_IF_NULL(size); - int tmp_size = 1; + size_t tmp_size = 1; uint32_t type_size = 4; for (size_t i = 0; i < shape.size(); i++) { - IntMulWithOverflowCheck(tmp_size, SizeToInt(shape[i]), &tmp_size); + tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]); } if (!GetHcomTypeSize(data_type, &type_size)) { return false; } - IntMulWithOverflowCheck(tmp_size, UintToInt(type_size), &tmp_size); - *size = IntToSize(tmp_size); + *size = SizetMulWithOverflowCheck(tmp_size, type_size); MS_LOG(INFO) << "size[" << *size << "]"; return true; diff --git a/mindspore/ccsrc/utils/convert_utils_base.h b/mindspore/ccsrc/utils/convert_utils_base.h index 76d89303249bf2717ee259f70d421ffca3b67204..3638a43e6afa6e075e57450cba17573fabbc24f6 100644 --- a/mindspore/ccsrc/utils/convert_utils_base.h +++ b/mindspore/ccsrc/utils/convert_utils_base.h @@ -102,6 +102,16 @@ inline void IntMulWithOverflowCheck(int a, int b, int *c) { *c = out; } +inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { + size_t out = a * b; + if (a != 0) { + if ((out / a) != b) { + MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + return out; +} + inline uint8_t *AddressOffset(void *address, size_t offset) { MS_EXCEPTION_IF_NULL(address); return static_cast(address) + offset;