diff --git a/paddle/fluid/eager/tests/task_tests/nan_inf_utils_test.cc b/paddle/fluid/eager/tests/task_tests/nan_inf_utils_test.cc index 73d213f71148f7898b2b584deda1a48e59d1f543..86f863bdffa6d55cc6dad608e22f1c2aca078ad1 100644 --- a/paddle/fluid/eager/tests/task_tests/nan_inf_utils_test.cc +++ b/paddle/fluid/eager/tests/task_tests/nan_inf_utils_test.cc @@ -30,32 +30,30 @@ PD_DECLARE_KERNEL(strings_empty, CPU, ALL_LAYOUT); namespace egr { -#define CHECK_NAN_INF(tensors) \ - { \ - bool caught_exception = false; \ - try { \ - CheckTensorHasNanOrInf("nan_inf_test", tensors); \ - } catch (paddle::platform::EnforceNotMet & error) { \ - caught_exception = true; \ - std::string ex_msg = error.what(); \ - EXPECT_TRUE(ex_msg.find("There are `nan` or `inf` in tensor") != \ - std::string::npos); \ - } \ - EXPECT_TRUE(caught_exception); \ +#define CHECK_NAN_INF(tensors) \ + { \ + bool caught_exception = false; \ + try { \ + CheckTensorHasNanOrInf("nan_inf_test", tensors); \ + } catch (paddle::platform::EnforceNotMet & error) { \ + caught_exception = true; \ + std::string ex_msg = error.what(); \ + EXPECT_TRUE(ex_msg.find("There are NAN or INF") != std::string::npos); \ + } \ + EXPECT_TRUE(caught_exception); \ } -#define CHECK_NO_NAN_INF(tensors) \ - { \ - bool caught_exception = false; \ - try { \ - CheckTensorHasNanOrInf("nan_inf_test", tensors); \ - } catch (paddle::platform::EnforceNotMet & error) { \ - caught_exception = true; \ - std::string ex_msg = error.what(); \ - EXPECT_TRUE(ex_msg.find("There are `nan` or `inf` in tensor") != \ - std::string::npos); \ - } \ - EXPECT_FALSE(caught_exception); \ +#define CHECK_NO_NAN_INF(tensors) \ + { \ + bool caught_exception = false; \ + try { \ + CheckTensorHasNanOrInf("nan_inf_test", tensors); \ + } catch (paddle::platform::EnforceNotMet & error) { \ + caught_exception = true; \ + std::string ex_msg = error.what(); \ + EXPECT_TRUE(ex_msg.find("There are NAN or INF") != std::string::npos); \ + } \ + EXPECT_FALSE(caught_exception); \ } TEST(NanInfUtils, Functions) { diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc index f80bb94b30b648e9829e7135d45f2587726b08d4..30046b2d1d44e8c89f033aebc7619ed6a5980c99 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/phi/common/amp_type_traits.h" #ifdef PADDLE_WITH_ASCEND_CL #include "paddle/fluid/platform/device/npu/npu_op_runner.h" @@ -24,6 +25,8 @@ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h" +DECLARE_int32(check_nan_inf_level); + namespace paddle { namespace framework { namespace details { @@ -90,7 +93,7 @@ static void InitWhiteListFormEnv() { const char* op_role_skip = std::getenv("PADDLE_INF_NAN_SKIP_ROLE"); const char* op_var_skip = std::getenv("PADDLE_INF_NAN_SKIP_VAR"); - if (op_type_skip != NULL) { + if (op_type_skip) { std::stringstream ss(op_type_skip); std::string op_type; while (std::getline(ss, op_type, ',')) { @@ -98,7 +101,7 @@ static void InitWhiteListFormEnv() { } } - if (op_role_skip != NULL) { + if (op_role_skip) { std::stringstream ss(op_role_skip); std::string op_role; while (std::getline(ss, op_role, ',')) { @@ -113,7 +116,7 @@ static void InitWhiteListFormEnv() { } } - if (op_var_skip != NULL) { + if 
(op_var_skip) { std::stringstream ss(op_var_skip); std::string op_var; while (std::getline(ss, op_var, ',')) { @@ -131,175 +134,101 @@ static void InitWhiteListFormEnv() { } } -template -static void PrintNanInf(const T* value, - const size_t numel, - int print_num, - const std::string& op_type, - const std::string& var_name, - bool abort = true) { - T min_value = std::numeric_limits::max(); - T max_value = std::numeric_limits::min(); - size_t nan_count, inf_count, num_count; - nan_count = inf_count = num_count = 0; - - // CPU print num value - for (size_t i = 0; i < numel; ++i) { - size_t count = 0; - if (std::isnan(value[i])) { - count = nan_count++; - } else if (std::isinf(value[i])) { - count = inf_count++; - } else { - count = num_count++; - min_value = std::min(min_value, value[i]); - max_value = std::max(max_value, value[i]); - } - - if (count < static_cast(print_num)) { - printf("numel:%zu index:%zu value:%f\n", - numel, - i, - static_cast(value[i])); - } - } - printf( - "In cpu, there has %zu,%zu,%zu nan,inf,num. " - "And in num, min_value is %f, max_value is %f\n", - nan_count, - inf_count, - num_count, - static_cast(min_value), - static_cast(max_value)); - if (abort) { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "There are `nan` or `inf` in tensor (%s) of operator (%s).", - var_name, - op_type)); - } -} +template < + typename T, + std::enable_if_t>::value && + !std::is_same>::value, + bool> = true> +static void CheckNanInfCpuImpl(const T* value_ptr, + const int64_t numel, + const std::string& cpu_hint_str) { + using MT = typename phi::dtype::template MPTypeTrait::Type; + +#ifdef _OPENMP + // Use maximum 4 threads to collect the nan and inf information. + int num_threads = std::max(omp_get_num_threads(), 1); + num_threads = std::min(num_threads, 4); +#else + int num_threads = 1; +#endif -// openmp 4.0, reduction with fp16 -#if defined _OPENMP && _OPENMP >= 201307 -// more detail see: 180 page of -// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf -#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in) -#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \ - omp_in) -#pragma omp declare reduction(+ : paddle::platform::complex < \ - float > : omp_out += omp_in) -#pragma omp declare reduction(+ : paddle::platform::complex < \ - double > : omp_out += omp_in) + std::vector thread_num_nan(num_threads, 0); + std::vector thread_num_inf(num_threads, 0); + std::vector thread_min_value(num_threads, static_cast(value_ptr[0])); + std::vector thread_max_value(num_threads, static_cast(value_ptr[0])); + std::vector thread_mean_value(num_threads, static_cast(0)); +#ifdef _OPENMP +#pragma omp parallel num_threads(num_threads) #endif - -template -static void CheckNanInf(const T* value, - const size_t numel, - int print_num, - const std::string& op_type, - const std::string& var_name) { - T sum = static_cast(0.0); -#if defined _OPENMP && _OPENMP >= 201307 -#pragma omp parallel for simd reduction(+ : sum) -#elif defined _OPENMP -#pragma omp parallel for reduction(+ : sum) + { +#ifdef _OPENMP + int64_t tid = omp_get_thread_num(); + int64_t chunk_size = (numel + num_threads - 1) / num_threads; + int64_t begin = tid * chunk_size; + int64_t end = chunk_size + begin > numel ? 
numel : chunk_size + begin; +#else + int64_t tid = 0; + int64_t begin = 0; + int64_t end = numel; #endif - for (size_t i = 0; i < numel; ++i) { - sum += (value[i] - value[i]); - } + for (int64_t i = begin; i < end; ++i) { + MT value = static_cast(value_ptr[i]); - if (std::isnan(sum) || std::isinf(sum)) { - PrintNanInf(value, numel, print_num, op_type, var_name); - } -} + thread_min_value[tid] = std::min(thread_min_value[tid], value); + thread_max_value[tid] = std::max(thread_max_value[tid], value); + thread_mean_value[tid] += value / static_cast(numel); -#if defined _OPENMP && _OPENMP >= 201307 -// openmp4.0 not need to specialization fp16 -#elif defined _OPENMP -template <> -void CheckNanInf( - const paddle::platform::float16* value, - const size_t numel, - int print_num, - const std::string& op_type, - const std::string& var_name) { - float sum = 0.0f; -#pragma omp parallel for reduction(+ : sum) - for (size_t i = 0; i < numel; ++i) { - sum += static_cast(value[i] - value[i]); - } - - if (std::isnan(sum) || std::isinf(sum)) { - PrintNanInf(value, numel, print_num, op_type, var_name); + if (std::isnan(value)) { + thread_num_nan[tid] += 1; + } else if (std::isinf(value)) { + thread_num_inf[tid] += 1; + } + } } -} -template <> -void CheckNanInf( - const paddle::platform::bfloat16* value, - const size_t numel, - int print_num, - const std::string& op_type, - const std::string& var_name) { - float sum = 0.0f; -#pragma omp parallel for reduction(+ : sum) - for (size_t i = 0; i < numel; ++i) { - sum += static_cast(value[i] - value[i]); + int64_t num_nan = 0; + int64_t num_inf = 0; + MT min_value = thread_min_value[0]; + MT max_value = thread_max_value[0]; + MT mean_value = static_cast(0); + for (int i = 0; i < num_threads; ++i) { + num_nan += thread_num_nan[i]; + num_inf += thread_num_inf[i]; + min_value = std::min(thread_min_value[i], min_value); + max_value = std::max(thread_max_value[i], max_value); + mean_value += thread_mean_value[i]; } - if (std::isnan(sum) || std::isinf(sum)) { - PrintNanInf(value, numel, print_num, op_type, var_name); - } + PrintForDifferentLevel(cpu_hint_str.c_str(), + numel, + num_nan, + num_inf, + max_value, + min_value, + mean_value, + FLAGS_check_nan_inf_level); } -template <> -void CheckNanInf>( - const paddle::platform::complex* value, - const size_t numel, - int print_num, - const std::string& op_type, - const std::string& var_name) { - float real_sum = 0.0f; -#pragma omp parallel for reduction(+ : real_sum) - for (size_t i = 0; i < numel; ++i) { - real_sum += (value[i].real - value[i].real); - } - - float imag_sum = 0.0f; -#pragma omp parallel for reduction(+ : imag_sum) - for (size_t i = 0; i < numel; ++i) { - imag_sum += (value[i].imag - value[i].imag); - } - - if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || - std::isinf(imag_sum)) { - // hot fix for compile failed in gcc4.8 - // here also need print detail info of nan or inf later - PADDLE_THROW(platform::errors::PreconditionNotMet( - "There are `nan` or `inf` in tensor (%s) of operator (%s).", - var_name, - op_type)); - } -} +template < + typename T, + std::enable_if_t>::value || + std::is_same>::value, + bool> = true> +void CheckNanInfCpuImpl(const T* value_ptr, + const int64_t numel, + const std::string& cpu_hint_str) { + using RealType = typename T::value_type; -template <> - void CheckNanInf < paddle::platform::complex < double >>> - (const paddle::platform::complex* value, - const size_t numel, - int print_num, - const std::string& op_type, - const std::string& var_name) { - 
double real_sum = 0.0; -#pragma omp parallel for reduction(+ : real_sum) - for (size_t i = 0; i < numel; ++i) { - real_sum += (value[i].real - value[i].real); - } + RealType real_sum = 0.0f, imag_sum = 0.0f; - double imag_sum = 0.0; -#pragma omp parallel for reduction(+ : imag_sum) - for (size_t i = 0; i < numel; ++i) { - imag_sum += (value[i].imag - value[i].imag); +#ifdef _OPENMP +#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum) +#endif + for (int64_t i = 0; i < numel; ++i) { + T value = value_ptr[i]; + real_sum += (value.real - value.real); + imag_sum += (value.imag - value.imag); } if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || @@ -307,14 +236,10 @@ template <> // hot fix for compile failed in gcc4.8 // here also need print detail info of nan or inf later PADDLE_THROW(platform::errors::PreconditionNotMet( - "There are `nan` or `inf` in tensor (%s) of operator (%s).", - var_name, - op_type)); + "There are NAN or INF in %s.", cpu_hint_str)); } } -#endif - template <> template void TensorCheckerVisitor::apply( @@ -323,10 +248,9 @@ void TensorCheckerVisitor::apply( std::is_same>::value || std::is_same>::value>::type*) const { - // use env strategy control in future, -1=print_all. - int print_num = 3; - CheckNanInf( - tensor_.data(), tensor_.numel(), print_num, op_type_, var_name_); + std::string cpu_hint_str = + GetCpuHintString(op_type, var_name, tensor.place()); + CheckNanInfCpuImpl(tensor.data(), tensor.numel(), cpu_hint_str); } template <> @@ -371,8 +295,8 @@ void CheckVarHasNanOrInf(const std::string& op_type, tensor_check(op_type, var_name, *tensor, place); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "phi::DenseTensor[%s] use gpu place. PaddlePaddle must compile with " - "GPU.", + "phi::DenseTensor[%s] use gpu place. PaddlePaddle must compile " + "with GPU.", var_name)); #endif return; @@ -406,8 +330,8 @@ void CheckVarHasNanOrInf(const std::string& op_type, var_name)); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "phi::DenseTensor[%s] use xpu place. PaddlePaddle must compile with " - "XPU.", + "phi::DenseTensor[%s] use xpu place. PaddlePaddle must compile " + "with XPU.", var_name)); #endif return; @@ -440,8 +364,8 @@ void CheckVarHasNanOrInf(const std::string& op_type, var_name)); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "phi::DenseTensor[%s] use npu place. PaddlePaddle must compile with " - "NPU.", + "phi::DenseTensor[%s] use npu place. PaddlePaddle must compile " + "with NPU.", var_name)); #endif return; diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cu b/paddle/fluid/framework/details/nan_inf_utils_detail.cu index abf575b4ca5453776f787a98ade6f4d2b1e1dde5..629ab737055a47c9493c686f982e0c05cf7441e2 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cu +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cu @@ -138,6 +138,54 @@ __global__ void CheckNanInfKernel(const T* value, PrintNanInfKernel(value, numel, print_num, debug_info); } +template +__device__ T BlockReduce(T value) { + __shared__ T shared_mem[1024]; + + shared_mem[threadIdx.x] = value; + __syncthreads(); + + for (int stride = blockDim.x >> 1; stride > 0; stride = stride >> 1) { + if (threadIdx.x < stride) { + T value0 = shared_mem[threadIdx.x]; + T value1 = shared_mem[threadIdx.x + stride]; + T reduce_value; + if (ReduceType == 0) { + // max + reduce_value = value0 > value1 ? value0 : value1; + } else if (ReduceType == 1) { + // min + reduce_value = value0 < value1 ? 
value0 : value1; + } else if (ReduceType == 2) { + // sum + reduce_value = value0 + value1; + } + shared_mem[threadIdx.x] = reduce_value; + } + + if (stride > 16) { + __syncthreads(); + } + } + + __syncthreads(); + return shared_mem[0]; +} + +__device__ void BlockReduceNumNanInfAndWrite(const int64_t num_nan, + const int64_t num_inf, + int64_t offset, + int64_t* num_nan_ptr, + int64_t* num_inf_ptr) { + int64_t block_num_nan = BlockReduce(num_nan); + int64_t block_num_inf = BlockReduce(num_inf); + + if (threadIdx.x == 0) { + num_nan_ptr[offset] = block_num_nan; + num_inf_ptr[offset] = block_num_inf; + } +} + template < typename T, std::enable_if_t>::value || @@ -183,15 +231,16 @@ __device__ void BlockReduceMaxMinAndWrite(const T max_value, template __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr, const int64_t numel, - int* found_nan_inf_ptr, + int64_t* block_num_nan_ptr, + int64_t* block_num_inf_ptr, MT* tensor_block_max_ptr, MT* tensor_block_min_ptr, MT* tensor_block_mean_ptr) { - bool has_nan = false; - bool has_inf = false; - int64_t i = threadIdx.x + blockIdx.x * blockDim.x; + int64_t num_nan = 0; + int64_t num_inf = 0; + MT max_value = static_cast(i < numel ? value_ptr[i] : value_ptr[0]); MT min_value = static_cast(i < numel ? value_ptr[i] : value_ptr[0]); MT mean_value = static_cast(0); @@ -203,25 +252,14 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr, mean_value += value / static_cast(numel); if (isnan(value)) { - has_nan = true; - } - if (isinf(value)) { - has_inf = true; - } - - if (has_nan || has_inf) { - if (!tensor_block_max_ptr && !tensor_block_min_ptr && - !tensor_block_mean_ptr) { - break; - } + num_nan += 1; + } else if (isinf(value)) { + num_inf += 1; } } - if (has_nan) { - found_nan_inf_ptr[0] = 1; - } - if (has_inf) { - found_nan_inf_ptr[1] = 1; - } + + BlockReduceNumNanInfAndWrite( + num_nan, num_inf, blockIdx.x, block_num_nan_ptr, block_num_inf_ptr); BlockReduceMaxMinAndWrite(max_value, min_value, @@ -232,32 +270,9 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr, tensor_block_mean_ptr); } -template ::value, bool> = true> -__device__ bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) { - if (check_nan_inf_level >= 3) { - return true; - } else if (check_nan_inf_level >= 2) { - MT fp16_max = - static_cast(std::numeric_limits::max()); - return max_value > fp16_max || min_value < -fp16_max; - } - return false; -} - -template ::value, bool> = true> -__device__ bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) { - if (check_nan_inf_level >= 3) { - return true; - } - return false; -} - template -__global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr, +__global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr, + const int64_t* block_num_inf_ptr, const MT* tensor_block_max_ptr, const MT* tensor_block_min_ptr, const MT* tensor_block_mean_ptr, @@ -266,8 +281,14 @@ __global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr, int64_t numel_max_min, int check_nan_inf_level) { if (blockIdx.x == 0 && threadIdx.x == 0) { - int has_nan = found_nan_inf_ptr[0]; - int has_inf = found_nan_inf_ptr[1]; + int64_t num_nan = 0; + int64_t num_inf = 0; + + // numel_max_min <= 128 + for (int64_t i = 0; i < numel_max_min; ++i) { + num_nan += block_num_nan_ptr[i]; + num_inf += block_num_inf_ptr[i]; + } MT max_value = static_cast(0); MT min_value = static_cast(0); @@ -289,67 +310,31 @@ __global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr, } } - if (has_nan || has_inf) 
{ - if (check_nan_inf_level == 0) { - PADDLE_ENFORCE(false, - "===[PRECISION] [ERROR] in %s, numel=%ld, find_nan=%d, " - "find_inf=%d, " - "max=%e, min=%e, mean=%e===\n", - debug_info, - numel, - has_nan, - has_inf, - static_cast(max_value), - static_cast(min_value), - static_cast(mean_value)); - } else if (check_nan_inf_level >= 1) { - printf( - "===[PRECISION] [ERROR] in %s, numel=%ld, find_nan=%d, " - "find_inf=%d, " - "max=%e, min=%e, mean=%e===\n", - debug_info, - numel, - has_nan, - has_inf, - static_cast(max_value), - static_cast(min_value), - static_cast(mean_value)); - } - } else if (NeedPrint(max_value, min_value, check_nan_inf_level)) { - printf("[PRECISION] in %s, numel=%ld, max=%e, min=%e, mean=%e\n", - debug_info, - numel, - static_cast(max_value), - static_cast(min_value), - static_cast(mean_value)); - } + PrintForDifferentLevel(debug_info, + numel, + num_nan, + num_inf, + max_value, + min_value, + mean_value, + check_nan_inf_level); } } -template <> template -void TensorCheckerVisitor::apply( - typename std::enable_if< - std::is_floating_point::value || - std::is_same>::value || - std::is_same>::value>::type*) - const { - auto* dev_ctx = reinterpret_cast( - platform::DeviceContextPool::Instance().Get(tensor_.place())); - int dev_id = tensor_.place().device; +static char* GetGpuHintStringPtr(const phi::GPUContext& ctx, + const std::string& op_type, + const std::string& var_name, + int dev_id) { PADDLE_ENFORCE_EQ( (dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()), true, platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d", multi_op_var2gpu_str_mutex().size())); - std::string dtype_str = DataTypeToString(DataTypeTrait::DataType()); - if (dtype_str == "::paddle::platform::float16") { - dtype_str = "float16"; - } - std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ + - "] [dtype=" + dtype_str + "]"; - char* gpu_str_ptr = NULL; + std::string op_var = + GetCpuHintString(op_type, var_name, ctx.GetPlace(), dev_id); + char* gpu_str_ptr = nullptr; { auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id); @@ -358,9 +343,9 @@ void TensorCheckerVisitor::apply( std::lock_guard guard(op_var2gpu_str_mutex); if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) { // insert auto gpu_str_tensor = paddle::memory::Alloc( - dev_ctx->GetPlace(), + ctx.GetPlace(), op_var.length() + 1, - phi::Stream(reinterpret_cast(dev_ctx->stream()))); + phi::Stream(reinterpret_cast(ctx.stream()))); gpu_str_ptr = reinterpret_cast(gpu_str_tensor->ptr()); op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor)); @@ -378,13 +363,13 @@ void TensorCheckerVisitor::apply( iter->first.c_str(), op_var.length() + 1, hipMemcpyHostToDevice, - dev_ctx->stream())); + ctx.stream())); #else PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1, cudaMemcpyHostToDevice, - dev_ctx->stream())); + ctx.stream())); #endif } else { // get auto iter = op_var2gpu_str.find(op_var); @@ -397,6 +382,22 @@ void TensorCheckerVisitor::apply( gpu_str_ptr = reinterpret_cast(iter->second->ptr()); } } + return gpu_str_ptr; +} + +template <> +template +void TensorCheckerVisitor::apply( + typename std::enable_if< + std::is_floating_point::value || + std::is_same>::value || + std::is_same>::value>::type*) + const { + auto* dev_ctx = reinterpret_cast( + platform::DeviceContextPool::Instance().Get(tensor.place())); + int dev_id = tensor.place().device; + char* gpu_str_ptr = + GetGpuHintStringPtr(*dev_ctx, op_type, var_name, dev_id); #ifdef __HIPCC__ // 
HIP will throw GPU memory access fault if threads > 256 @@ -406,7 +407,7 @@ void TensorCheckerVisitor::apply( #endif size_t blocks = std::min(static_cast(128), - static_cast((tensor_.numel() + threads - 1) / threads)); + static_cast((tensor.numel() + threads - 1) / threads)); #ifdef __HIPCC__ int print_num = 3; @@ -415,44 +416,46 @@ void TensorCheckerVisitor::apply( dim3(threads), 0, dev_ctx->stream(), - tensor_.data(), - tensor_.numel(), + tensor.data(), + tensor.numel(), print_num, gpu_str_ptr); #else using MT = typename phi::dtype::MPTypeTrait::Type; - phi::DenseTensor found_nan_inf; - found_nan_inf.Resize({2}); - int* found_nan_inf_ptr = found_nan_inf.mutable_data(tensor_.place()); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - found_nan_inf_ptr, 0, 2 * sizeof(int), dev_ctx->stream())); - int64_t numel_max_min = blocks; + phi::DenseTensor block_num_nan_inf; + block_num_nan_inf.Resize({static_cast(2 * numel_max_min)}); + int64_t* block_num_nan_ptr = + block_num_nan_inf.mutable_data(tensor.place()); + int64_t* block_num_inf_ptr = block_num_nan_ptr + numel_max_min; + phi::DenseTensor tensor_block_max_min; tensor_block_max_min.Resize({static_cast(3 * numel_max_min)}); MT* tensor_block_max_ptr = - tensor_block_max_min.mutable_data(tensor_.place()); + tensor_block_max_min.mutable_data(tensor.place()); MT* tensor_block_min_ptr = tensor_block_max_ptr + numel_max_min; MT* tensor_block_mean_ptr = tensor_block_max_ptr + 2 * numel_max_min; FindNanInfAndBlockMaxMin - <<stream()>>>(tensor_.data(), - tensor_.numel(), - found_nan_inf_ptr, + <<stream()>>>(tensor.data(), + tensor.numel(), + block_num_nan_ptr, + block_num_inf_ptr, tensor_block_max_ptr, tensor_block_min_ptr, tensor_block_mean_ptr); int check_nan_inf_level = FLAGS_check_nan_inf_level; FindGlobalMaxMinAndPrint - <<<1, 1, 0, dev_ctx->stream()>>>(found_nan_inf_ptr, + <<<1, 1, 0, dev_ctx->stream()>>>(block_num_nan_ptr, + block_num_inf_ptr, tensor_block_max_ptr, tensor_block_min_ptr, tensor_block_mean_ptr, gpu_str_ptr, - tensor_.numel(), + tensor.numel(), numel_max_min, check_nan_inf_level); #endif diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.h b/paddle/fluid/framework/details/nan_inf_utils_detail.h index 2a25bc7b68f366f3009909cc0a8cd8175f3e58f1..0adf23fd029218446108f55c2e8e3c98b2204fd1 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.h +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.h @@ -24,21 +24,114 @@ namespace paddle { namespace framework { namespace details { +template ::value, bool> = true> +HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) { + if (check_nan_inf_level >= 3) { + return true; + } else if (check_nan_inf_level >= 2) { + MT fp16_max = + static_cast(std::numeric_limits::max()); + return max_value > fp16_max || min_value < -fp16_max; + } + return false; +} + +template ::value, bool> = true> +HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) { + if (check_nan_inf_level >= 3) { + return true; + } + return false; +} + +template +HOSTDEVICE void PrintForDifferentLevel(const char* debug_info, + int64_t numel, + int64_t num_nan, + int64_t num_inf, + MT max_value, + MT min_value, + MT mean_value, + int check_nan_inf_level) { + if (num_nan > 0 || num_inf > 0) { + printf( + "[PRECISION] [ERROR] in %s, numel=%lld, num_nan=%lld, " + "num_inf=%lld, max=%e, min=%e, mean=%e\n", + debug_info, + static_cast(numel), // NOLINT + static_cast(num_nan), // NOLINT + static_cast(num_inf), // NOLINT + static_cast(max_value), + 
static_cast(min_value), + static_cast(mean_value)); + if (check_nan_inf_level == 0) { +#if defined(__NVCC__) || defined(__HIPCC__) + PADDLE_ENFORCE(false, + "There are NAN or INF (num_nan=%ld, num_inf=%lld) in %s.", + static_cast(num_nan), // NOLINT + static_cast(num_inf), // NOLINT + debug_info); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "There are NAN or INF (num_nan=%lld, num_inf=%lld) in %s.", + static_cast(num_nan), // NOLINT + static_cast(num_inf), // NOLINT + debug_info)); +#endif + } + } else if (NeedPrint(max_value, min_value, check_nan_inf_level)) { + printf("[PRECISION] in %s, numel=%lld, max=%e, min=%e, mean=%e\n", + debug_info, + static_cast(numel), // NOLINT + static_cast(max_value), + static_cast(min_value), + static_cast(mean_value)); + } +} + +template +inline std::string GetCpuHintString(const std::string& op_type, + const std::string& var_name, + const phi::Place& place, + int device_id = -1) { + std::string dtype_str = DataTypeToString(DataTypeTrait::DataType()); + if (dtype_str == "float") { + dtype_str = "fp32"; + } else if (dtype_str == "double") { + dtype_str = "fp64"; + } else if (dtype_str == "::paddle::platform::float16") { + dtype_str = "fp16"; + } else if (dtype_str == "::paddle::platform::bfloat16") { + dtype_str = "bf16"; + } + + std::stringstream ss; + if (platform::is_gpu_place(place)) { + ss << "[device=gpu:" << device_id << ", "; + } else { + ss << "[device=cpu, "; + } + ss << "op=" << op_type << ", tensor=" << var_name << ", dtype=" << dtype_str + << "]"; + return ss.str(); +} + template struct TensorCheckerVisitor { - TensorCheckerVisitor(const std::string& op_type, - const std::string& var_name, - const phi::DenseTensor& tensor, - const platform::Place& place) - : op_type_(op_type), - var_name_(var_name), - tensor_(tensor), - place_(place) {} + TensorCheckerVisitor(const std::string& o, + const std::string& v, + const phi::DenseTensor& t, + const platform::Place& p) + : op_type(o), var_name(v), tensor(t), place(p) {} template void apply( typename std::enable_if::value>::type* = 0) const { - VLOG(10) << var_name_ << " need not to check, it's type is not float point"; + VLOG(10) << var_name << " need not to check, it's type is not float point"; } template @@ -49,10 +142,10 @@ struct TensorCheckerVisitor { std::is_same>::value>::type* = 0) const; - std::string op_type_; - std::string var_name_; - const phi::DenseTensor& tensor_; - const platform::Place& place_; + std::string op_type; + std::string var_name; + const phi::DenseTensor& tensor; + const platform::Place& place; }; template diff --git a/python/paddle/fluid/tests/unittests/test_nan_inf.py b/python/paddle/fluid/tests/unittests/test_nan_inf.py index 289b5ec40f9d5f79f63b12f92918d27a76d42217..4f0e02fdf613c057a521c838e9e556ac58b958f6 100644 --- a/python/paddle/fluid/tests/unittests/test_nan_inf.py +++ b/python/paddle/fluid/tests/unittests/test_nan_inf.py @@ -17,9 +17,9 @@ import subprocess import sys import unittest -import paddle +import numpy as np -paddle.enable_static() +import paddle class TestNanInf(unittest.TestCase): @@ -47,12 +47,7 @@ class TestNanInf(unittest.TestCase): print(err) # in python3, type(out+err) is 'bytes', need use encode - if paddle.fluid.core.is_compiled_with_cuda(): - assert (out + err).find('find_nan=1, find_inf=1'.encode()) != -1 - else: - assert (out + err).find( - 'There are `nan` or `inf` in tensor'.encode() - ) != -1 + assert (out + err).find('There are NAN or INF'.encode()) != -1 def test_nan_inf_in_static_mode(self): self._python_interp += " 
check_nan_inf_base.py" @@ -75,5 +70,97 @@ class TestNanInfEnv(TestNanInf): ) +class TestNanInfCheckResult(unittest.TestCase): + def generate_inputs(self, shape, dtype="float32"): + data = np.random.random(size=shape).astype(dtype) + # [-10, 10) + x = (data * 20 - 10) * np.random.randint( + low=0, high=2, size=shape + ).astype(dtype) + y = np.random.randint(low=0, high=2, size=shape).astype(dtype) + return x, y + + def get_reference_num_nan_inf(self, x): + out = np.log(x) + num_nan = np.sum(np.isnan(out)) + num_inf = np.sum(np.isinf(out)) + print("[reference] num_nan={}, num_inf={}".format(num_nan, num_inf)) + return num_nan, num_inf + + def get_num_nan_inf(self, x_np, use_cuda=True, add_assert=False): + num_nan = 0 + num_inf = 0 + try: + if use_cuda: + paddle.device.set_device("gpu:0") + else: + paddle.device.set_device("cpu") + x = paddle.to_tensor(x_np) + out = paddle.log(x) + sys.stdout.flush() + if add_assert: + assert False + except Exception as e: + # Cannot catch the log in CUDA kernel. + err_str_list = ( + str(e) + .replace("(", " ") + .replace(")", " ") + .replace(",", " ") + .split(" ") + ) + for err_str in err_str_list: + if "num_nan" in err_str: + num_nan = int(err_str.split("=")[1]) + elif "num_inf" in err_str: + num_inf = int(err_str.split("=")[1]) + print("[paddle] num_nan={}, num_inf={}".format(num_nan, num_inf)) + return num_nan, num_inf + + def test_num_nan_inf(self): + def _check_num_nan_inf(use_cuda): + shape = [32, 32] + x_np, _ = self.generate_inputs(shape) + num_nan_np, num_inf_np = self.get_reference_num_nan_inf(x_np) + add_assert = (num_nan_np + num_inf_np) > 0 + num_nan, num_inf = self.get_num_nan_inf(x_np, use_cuda, add_assert) + if not use_cuda: + assert num_nan == num_nan_np and num_inf == num_inf_np + + paddle.set_flags( + {"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 0} + ) + _check_num_nan_inf(use_cuda=False) + if paddle.fluid.core.is_compiled_with_cuda(): + _check_num_nan_inf(use_cuda=True) + + def check_nan_inf_level(self, use_cuda, dtype): + shape = [8, 8] + x_np, y_np = self.generate_inputs(shape, dtype) + + if use_cuda: + paddle.device.set_device("gpu:0") + else: + paddle.device.set_device("cpu") + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + out = paddle.log(x * 1e6) / y + + def test_check_nan_inf_level_float32(self): + paddle.set_flags( + {"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 2} + ) + self.check_nan_inf_level(use_cuda=False, dtype="float32") + if paddle.fluid.core.is_compiled_with_cuda(): + self.check_nan_inf_level(use_cuda=True, dtype="float32") + + def test_check_nan_inf_level_float16(self): + paddle.set_flags( + {"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3} + ) + if paddle.fluid.core.is_compiled_with_cuda(): + self.check_nan_inf_level(use_cuda=True, dtype="float16") + + if __name__ == '__main__': unittest.main()
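
Note for reviewers: the core CPU-side change is that CheckNanInfCpuImpl splits the tensor into one contiguous chunk per thread, lets each thread accumulate its own nan/inf counts plus min/max/mean, and merges the per-thread partials before handing them to PrintForDifferentLevel. The following is a minimal standalone sketch of that pattern, assuming plain float data, numel > 0, and std::thread in place of OpenMP; the names NanInfStats and CollectNanInfStats are illustrative and not part of the patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <thread>
#include <vector>

// Illustrative sketch (not the patch's code): per-thread nan/inf statistics
// over contiguous chunks, merged after the parallel region, mirroring the
// accumulation pattern of CheckNanInfCpuImpl.
struct NanInfStats {
  int64_t num_nan = 0;
  int64_t num_inf = 0;
  float min_value = 0.0f;
  float max_value = 0.0f;
  float mean_value = 0.0f;
};

NanInfStats CollectNanInfStats(const float* data, int64_t numel,
                               int num_threads) {
  std::vector<NanInfStats> partial(num_threads);
  for (auto& stats : partial) {
    // Like the patch, initialize min/max from the first element (numel > 0).
    stats.min_value = stats.max_value = data[0];
  }

  std::vector<std::thread> workers;
  const int64_t chunk = (numel + num_threads - 1) / num_threads;
  for (int tid = 0; tid < num_threads; ++tid) {
    workers.emplace_back([&, tid]() {
      const int64_t begin = tid * chunk;
      const int64_t end = std::min(begin + chunk, numel);
      for (int64_t i = begin; i < end; ++i) {
        const float value = data[i];
        partial[tid].min_value = std::min(partial[tid].min_value, value);
        partial[tid].max_value = std::max(partial[tid].max_value, value);
        partial[tid].mean_value += value / static_cast<float>(numel);
        if (std::isnan(value)) {
          partial[tid].num_nan += 1;
        } else if (std::isinf(value)) {
          partial[tid].num_inf += 1;
        }
      }
    });
  }
  for (auto& worker : workers) {
    worker.join();
  }

  // Merge the per-thread partial results, as the patch does after the
  // OpenMP parallel region.
  NanInfStats total = partial[0];
  for (int i = 1; i < num_threads; ++i) {
    total.num_nan += partial[i].num_nan;
    total.num_inf += partial[i].num_inf;
    total.min_value = std::min(total.min_value, partial[i].min_value);
    total.max_value = std::max(total.max_value, partial[i].max_value);
    total.mean_value += partial[i].mean_value;
  }
  return total;
}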
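
On the GPU side, BlockReduce in nan_inf_utils_detail.cu performs the usual shared-memory tree reduction: in every round the lower half of the buffer absorbs the upper half, so blockDim.x values collapse to one in log2(blockDim.x) steps, and the ReduceType template parameter selects max, min, or sum. A serial host-side illustration of the sum case (hypothetical helper name TreeReduceSum, not in the patch):

#include <cstdint>
#include <vector>

// Serial illustration of the halving-stride tree reduction used by
// BlockReduce; each round the first `stride` slots absorb the next `stride`.
int64_t TreeReduceSum(std::vector<int64_t> buffer) {
  // buffer.size() plays the role of blockDim.x and is assumed a power of two.
  for (size_t stride = buffer.size() / 2; stride > 0; stride /= 2) {
    for (size_t i = 0; i < stride; ++i) {
      buffer[i] += buffer[i + stride];  // ReduceType == 2 (sum) in the kernel
    }
    // In the kernel, __syncthreads() is issued here while stride > 16, i.e.
    // it is skipped once the active threads fit inside a single warp.
  }
  return buffer.empty() ? 0 : buffer[0];
}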
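
Finally, the semantics of the new FLAGS_check_nan_inf_level flag, as implemented by PrintForDifferentLevel and NeedPrint in nan_inf_utils_detail.h: level 0 prints the statistics and throws when any nan or inf is found; level 1 prints but does not throw; level 2 additionally prints tensors whose max/min fall outside the float16 range (for dtypes other than float16); level 3 prints the statistics for every checked tensor. A condensed sketch of the print decision for a non-float16 tensor (hypothetical name ShouldPrintStats, not in the patch):

#include <cstdint>

// Returns true when the statistics should be printed; when nan/inf is found
// and check_nan_inf_level == 0, the real code additionally throws.
bool ShouldPrintStats(int64_t num_nan, int64_t num_inf, float max_value,
                      float min_value, int check_nan_inf_level) {
  if (num_nan > 0 || num_inf > 0) {
    return true;  // level 0 also aborts after printing
  }
  if (check_nan_inf_level >= 3) {
    return true;  // print every checked tensor
  }
  if (check_nan_inf_level >= 2) {
    // Flag values that would overflow float16, useful when screening a model
    // for mixed-precision training.
    const float fp16_max = 65504.0f;  // largest finite float16 value
    return max_value > fp16_max || min_value < -fp16_max;
  }
  return false;
}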