diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index c83994b5f22202f80a86e0d3b7252cc00171ae7f..f602a6acd8f651668bf9cb060fcd2a6093846c1b 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -65,16 +65,16 @@ bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input } else { auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); MS_EXCEPTION_IF_NULL(type_ptr); - int size_i = 1; + int64_t size_i = 1; for (size_t j = 0; j < shape_i.size(); j++) { - IntMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); + LongMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); - input_size_list->push_back(IntToSize(size_i)); + LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + input_size_list->push_back(LongToSize(size_i)); } } return true; @@ -97,16 +97,16 @@ bool SetIOSize(const std::shared_ptr &anf_node, const std::shared_ptr shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); MS_EXCEPTION_IF_NULL(type_ptr); - int size_i = 1; + int64_t size_i = 1; for (size_t j = 0; j < shape_i.size(); j++) { - IntMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); + LongMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); - output_size_list.push_back(IntToSize(size_i)); + LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + output_size_list.push_back(LongToSize(size_i)); } kernel_mod_ptr->SetOutputSizeList(output_size_list); return true; diff --git a/mindspore/ccsrc/utils/convert_utils_base.h b/mindspore/ccsrc/utils/convert_utils_base.h index 3638a43e6afa6e075e57450cba17573fabbc24f6..8960d6628b5b3c2832e0e8eec35a696af820368d 100644 --- a/mindspore/ccsrc/utils/convert_utils_base.h +++ b/mindspore/ccsrc/utils/convert_utils_base.h @@ -102,6 +102,17 @@ inline void IntMulWithOverflowCheck(int a, int b, int *c) { *c = out; } +inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) { + int64_t out = a * b; + if (a != 0) { + bool ok = ((out / a) != b); + if (ok) { + MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + *c = out; +} + inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { size_t out = a * b; if (a != 0) {