未验证 提交 1de32f82 编写于 作者: C Chen Weihang 提交者: GitHub

Hot fix complle failed in gcc4.8 caused by complex impl (#29254)

* hot fix complle failed in gcc4.8

* fix failed unittest
上级 642abe2a
...@@ -152,14 +152,12 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num, ...@@ -152,14 +152,12 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
static_cast<uint64_t>(i), static_cast<float>(value[i])); static_cast<uint64_t>(i), static_cast<float>(value[i]));
} }
} }
bool has_nan_inf = true;
printf("In cpu, there has %lu,%lu,%lu nan,inf,num\n", printf("In cpu, there has %lu,%lu,%lu nan,inf,num\n",
static_cast<uint64_t>(nan_count), static_cast<uint64_t>(inf_count), static_cast<uint64_t>(nan_count), static_cast<uint64_t>(inf_count),
static_cast<uint64_t>(num_count)); static_cast<uint64_t>(num_count));
PADDLE_ENFORCE_EQ(has_nan_inf, false, PADDLE_THROW(platform::errors::PreconditionNotMet(
platform::errors::PreconditionNotMet( "There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
"===ERROR: in [op=%s] [tensor=%s] find nan or inf===", op_type));
op_type, var_name));
} }
// openmp 4.0, reduction with fp16 // openmp 4.0, reduction with fp16
...@@ -231,14 +229,25 @@ template <> ...@@ -231,14 +229,25 @@ template <>
void CheckNanInf<paddle::platform::complex64>( void CheckNanInf<paddle::platform::complex64>(
const paddle::platform::complex64* value, const size_t numel, int print_num, const paddle::platform::complex64* value, const size_t numel, int print_num,
const std::string& op_type, const std::string& var_name) { const std::string& op_type, const std::string& var_name) {
paddle::platform::complex64 sum(0.0, 0.0); float real_sum = 0.0f;
#pragma omp parallel for reduction(+ : sum) #pragma omp parallel for reduction(+ : real_sum)
for (size_t i = 0; i < numel; ++i) { for (size_t i = 0; i < numel; ++i) {
sum += (value[i] - value[i]); real_sum += (value[i].real - value[i].real);
} }
if (std::isnan(sum) || std::isinf(sum)) { float imag_sum = 0.0f;
PrintNanInf(value, numel, print_num, op_type, var_name); #pragma omp parallel for reduction(+ : imag_sum)
for (size_t i = 0; i < numel; ++i) {
imag_sum += (value[i].imag - value[i].imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
op_type));
} }
} }
...@@ -246,17 +255,27 @@ template <> ...@@ -246,17 +255,27 @@ template <>
void CheckNanInf<paddle::platform::complex128>( void CheckNanInf<paddle::platform::complex128>(
const paddle::platform::complex128* value, const size_t numel, const paddle::platform::complex128* value, const size_t numel,
int print_num, const std::string& op_type, const std::string& var_name) { int print_num, const std::string& op_type, const std::string& var_name) {
paddle::platform::complex128 sum(0.0, 0.0); double real_sum = 0.0;
#pragma omp parallel for reduction(+ : sum) #pragma omp parallel for reduction(+ : real_sum)
for (size_t i = 0; i < numel; ++i) { for (size_t i = 0; i < numel; ++i) {
sum += (value[i] - value[i]); real_sum += (value[i].real - value[i].real);
} }
if (std::isnan(sum) || std::isinf(sum)) { double imag_sum = 0.0;
PrintNanInf(value, numel, print_num, op_type, var_name); #pragma omp parallel for reduction(+ : imag_sum)
for (size_t i = 0; i < numel; ++i) {
imag_sum += (value[i].imag - value[i].imag);
} }
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
op_type));
}
}
#endif #endif
template <> template <>
......
...@@ -50,7 +50,8 @@ class TestNanInf(unittest.TestCase): ...@@ -50,7 +50,8 @@ class TestNanInf(unittest.TestCase):
assert returncode == 0 assert returncode == 0
# in python3, type(out+err) is 'bytes', need use encode # in python3, type(out+err) is 'bytes', need use encode
assert (out + err).find('find nan or inf'.encode()) != -1 assert (out + err
).find('There are `nan` or `inf` in tensor'.encode()) != -1
class TestNanInfEnv(TestNanInf): class TestNanInfEnv(TestNanInf):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册