提交 9dfa6155 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2497 [AutoParallel]Fix_Hccl_to_support_big_tensor

Merge pull request !2497 from lichen/fix_hccl_to_support_big_tensor
......@@ -67,18 +67,17 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t
bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_t> &shape, size_t *size) {
MS_EXCEPTION_IF_NULL(size);
int tmp_size = 1;
size_t tmp_size = 1;
uint32_t type_size = 4;
for (size_t i = 0; i < shape.size(); i++) {
IntMulWithOverflowCheck(tmp_size, SizeToInt(shape[i]), &tmp_size);
tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]);
}
if (!GetHcomTypeSize(data_type, &type_size)) {
return false;
}
IntMulWithOverflowCheck(tmp_size, UintToInt(type_size), &tmp_size);
*size = IntToSize(tmp_size);
*size = SizetMulWithOverflowCheck(tmp_size, type_size);
MS_LOG(INFO) << "size[" << *size << "]";
return true;
......
......@@ -102,6 +102,16 @@ inline void IntMulWithOverflowCheck(int a, int b, int *c) {
*c = out;
}
inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {
size_t out = a * b;
if (a != 0) {
if ((out / a) != b) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
return out;
}
inline uint8_t *AddressOffset(void *address, size_t offset) {
MS_EXCEPTION_IF_NULL(address);
return static_cast<uint8_t *>(address) + offset;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册