提交 fb90ff16 编写于 作者: J jzg

increase the max size of tensor.

上级 94d0d45a
......@@ -67,13 +67,13 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
MS_EXCEPTION_IF_NULL(type_ptr);
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
input_size_list->push_back(LongToSize(size_i));
}
}
......@@ -99,13 +99,13 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
MS_EXCEPTION_IF_NULL(type_ptr);
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
output_size_list.push_back(LongToSize(size_i));
}
kernel_mod_ptr->SetOutputSizeList(output_size_list);
......
......@@ -587,7 +587,7 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr
int result = 1;
for (size_t i = 0; i < shpx_data.size(); i++) {
int value = GetValue<int>(shpx_data[i]);
IntMulWithOverflowCheck(result, value, &result);
result = IntMulWithOverflowCheck(result, value);
}
auto result_v = MakeValue(result);
......
......@@ -91,26 +91,26 @@ inline unsigned int UlongToUint(size_t u) {
return static_cast<unsigned int>(u);
}
inline void IntMulWithOverflowCheck(int a, int b, int *c) {
inline int IntMulWithOverflowCheck(int a, int b) {
int out = a * b;
if (a != 0) {
bool ok = ((out / a) != b);
if (ok) {
bool overflow = ((out / a) != b);
if (overflow) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
*c = out;
return out;
}
inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) {
inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) {
int64_t out = a * b;
if (a != 0) {
bool ok = ((out / a) != b);
if (ok) {
bool overflow = ((out / a) != b);
if (overflow) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
*c = out;
return out;
}
inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册