提交 d1e7b977 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2964 Increase the max tensor size

Merge pull request !2964 from jiangzhenguang/Increase-the-max-tensor-size
......@@ -65,16 +65,16 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
} else {
auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int size_i = 1;
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
input_size_list->push_back(IntToSize(size_i));
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
input_size_list->push_back(LongToSize(size_i));
}
}
return true;
......@@ -97,16 +97,16 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int size_i = 1;
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
output_size_list.push_back(IntToSize(size_i));
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
output_size_list.push_back(LongToSize(size_i));
}
kernel_mod_ptr->SetOutputSizeList(output_size_list);
return true;
......
......@@ -587,7 +587,7 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr
int result = 1;
for (size_t i = 0; i < shpx_data.size(); i++) {
int value = GetValue<int>(shpx_data[i]);
IntMulWithOverflowCheck(result, value, &result);
result = IntMulWithOverflowCheck(result, value);
}
auto result_v = MakeValue(result);
......
......@@ -91,15 +91,26 @@ inline unsigned int UlongToUint(size_t u) {
return static_cast<unsigned int>(u);
}
inline void IntMulWithOverflowCheck(int a, int b, int *c) {
inline int IntMulWithOverflowCheck(int a, int b) {
int out = a * b;
if (a != 0) {
bool ok = ((out / a) != b);
if (ok) {
bool overflow = ((out / a) != b);
if (overflow) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
*c = out;
return out;
}
inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) {
int64_t out = a * b;
if (a != 0) {
bool overflow = ((out / a) != b);
if (overflow) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
return out;
}
inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册