From 40bad6483f887eae43e467623dd581dfeb4248fb Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 2 Dec 2020 16:02:28 +0800 Subject: [PATCH] Hot fix complle failed in gcc4.8 caused by complex impl (#29254) (#29274) * hot fix complle failed in gcc4.8 * fix failed unittest --- .../framework/details/nan_inf_utils_detail.cc | 51 +++++++++++++------ .../fluid/tests/unittests/test_nan_inf.py | 3 +- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc index ceb358b47ad..797a254c951 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc @@ -152,14 +152,12 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num, static_cast(i), static_cast(value[i])); } } - bool has_nan_inf = true; printf("In cpu, there has %lu,%lu,%lu nan,inf,num\n", static_cast(nan_count), static_cast(inf_count), static_cast(num_count)); - PADDLE_ENFORCE_EQ(has_nan_inf, false, - platform::errors::PreconditionNotMet( - "===ERROR: in [op=%s] [tensor=%s] find nan or inf===", - op_type, var_name)); + PADDLE_THROW(platform::errors::PreconditionNotMet( + "There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name, + op_type)); } // openmp 4.0, reduction with fp16 @@ -231,14 +229,25 @@ template <> void CheckNanInf( const paddle::platform::complex64* value, const size_t numel, int print_num, const std::string& op_type, const std::string& var_name) { - paddle::platform::complex64 sum(0.0, 0.0); -#pragma omp parallel for reduction(+ : sum) + float real_sum = 0.0f; +#pragma omp parallel for reduction(+ : real_sum) for (size_t i = 0; i < numel; ++i) { - sum += (value[i] - value[i]); + real_sum += (value[i].real - value[i].real); } - if (std::isnan(sum) || std::isinf(sum)) { - PrintNanInf(value, numel, print_num, op_type, var_name); + float imag_sum = 0.0f; +#pragma omp parallel for reduction(+ : imag_sum) + for (size_t i = 0; i < numel; ++i) { + imag_sum += (value[i].imag - value[i].imag); + } + + if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || + std::isinf(imag_sum)) { + // hot fix for compile failed in gcc4.8 + // here also need print detail info of nan or inf later + PADDLE_THROW(platform::errors::PreconditionNotMet( + "There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name, + op_type)); } } @@ -246,17 +255,27 @@ template <> void CheckNanInf( const paddle::platform::complex128* value, const size_t numel, int print_num, const std::string& op_type, const std::string& var_name) { - paddle::platform::complex128 sum(0.0, 0.0); -#pragma omp parallel for reduction(+ : sum) + double real_sum = 0.0; +#pragma omp parallel for reduction(+ : real_sum) for (size_t i = 0; i < numel; ++i) { - sum += (value[i] - value[i]); + real_sum += (value[i].real - value[i].real); } - if (std::isnan(sum) || std::isinf(sum)) { - PrintNanInf(value, numel, print_num, op_type, var_name); + double imag_sum = 0.0; +#pragma omp parallel for reduction(+ : imag_sum) + for (size_t i = 0; i < numel; ++i) { + imag_sum += (value[i].imag - value[i].imag); } -} + if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || + std::isinf(imag_sum)) { + // hot fix for compile failed in gcc4.8 + // here also need print detail info of nan or inf later + PADDLE_THROW(platform::errors::PreconditionNotMet( + "There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name, + op_type)); + } +} #endif template <> diff --git a/python/paddle/fluid/tests/unittests/test_nan_inf.py b/python/paddle/fluid/tests/unittests/test_nan_inf.py index dc9ea5d957a..1673002cb79 100644 --- a/python/paddle/fluid/tests/unittests/test_nan_inf.py +++ b/python/paddle/fluid/tests/unittests/test_nan_inf.py @@ -50,7 +50,8 @@ class TestNanInf(unittest.TestCase): assert returncode == 0 # in python3, type(out+err) is 'bytes', need use encode - assert (out + err).find('find nan or inf'.encode()) != -1 + assert (out + err + ).find('There are `nan` or `inf` in tensor'.encode()) != -1 class TestNanInfEnv(TestNanInf): -- GitLab