Unverified · Commit 69e695b7 authored by Yiqun Liu, committed by GitHub

Enhance check_nan_inf implementation for CPU. (#48591)

* Enable printing of device info.

* Enhance the nan and inf checking for cpu.

* Implement a common print function.

* Unify the check of complex numbers.

* Rewrite the omp method.

* Count and print the number of nan and inf.

* Change the print content.

* Add unittest.
Parent 6698e8d1
......@@ -38,8 +38,7 @@ namespace egr {
} catch (paddle::platform::EnforceNotMet & error) { \
caught_exception = true; \
std::string ex_msg = error.what(); \
EXPECT_TRUE(ex_msg.find("There are `nan` or `inf` in tensor") != \
std::string::npos); \
EXPECT_TRUE(ex_msg.find("There are NAN or INF") != std::string::npos); \
} \
EXPECT_TRUE(caught_exception); \
}
......@@ -52,8 +51,7 @@ namespace egr {
} catch (paddle::platform::EnforceNotMet & error) { \
caught_exception = true; \
std::string ex_msg = error.what(); \
EXPECT_TRUE(ex_msg.find("There are `nan` or `inf` in tensor") != \
std::string::npos); \
EXPECT_TRUE(ex_msg.find("There are NAN or INF") != std::string::npos); \
} \
EXPECT_FALSE(caught_exception); \
}
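Both macros above implement the same pattern: run a statement that should (or should not) trip the checker, catch paddle::platform::EnforceNotMet, and verify that the message now contains "There are NAN or INF" instead of the old "There are `nan` or `inf` in tensor" text. A minimal standalone sketch of that pattern, using std::runtime_error in place of EnforceNotMet and plain assert in place of gtest:

```cpp
// Standalone sketch of the "exception message contains substring" pattern used
// by the macros above; std::runtime_error stands in for EnforceNotMet.
#include <cassert>
#include <stdexcept>
#include <string>

template <typename Fn>
bool ThrowsNanInfError(Fn&& fn) {
  try {
    fn();
  } catch (const std::runtime_error& error) {
    std::string ex_msg = error.what();
    return ex_msg.find("There are NAN or INF") != std::string::npos;
  }
  return false;
}

int main() {
  assert(ThrowsNanInfError([] {
    throw std::runtime_error("There are NAN or INF in [op=log, tensor=x].");
  }));
  assert(!ThrowsNanInfError([] { /* no NaN/Inf, nothing thrown */ }));
  return 0;
}
```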
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/amp_type_traits.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......@@ -24,6 +25,8 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
DECLARE_int32(check_nan_inf_level);
namespace paddle {
namespace framework {
namespace details {
......@@ -90,7 +93,7 @@ static void InitWhiteListFormEnv() {
const char* op_role_skip = std::getenv("PADDLE_INF_NAN_SKIP_ROLE");
const char* op_var_skip = std::getenv("PADDLE_INF_NAN_SKIP_VAR");
if (op_type_skip != NULL) {
if (op_type_skip) {
std::stringstream ss(op_type_skip);
std::string op_type;
while (std::getline(ss, op_type, ',')) {
......@@ -98,7 +101,7 @@ static void InitWhiteListFormEnv() {
}
}
if (op_role_skip != NULL) {
if (op_role_skip) {
std::stringstream ss(op_role_skip);
std::string op_role;
while (std::getline(ss, op_role, ',')) {
......@@ -113,7 +116,7 @@ static void InitWhiteListFormEnv() {
}
}
if (op_var_skip != NULL) {
if (op_var_skip) {
std::stringstream ss(op_var_skip);
std::string op_var;
while (std::getline(ss, op_var, ',')) {
......@@ -131,175 +134,101 @@ static void InitWhiteListFormEnv() {
}
}
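InitWhiteListFormEnv reads comma-separated skip lists from environment variables (PADDLE_INF_NAN_SKIP_ROLE and PADDLE_INF_NAN_SKIP_VAR appear in this hunk; the op-type variable name sits outside it) and splits each into individual entries. A self-contained sketch of that parsing, with no Paddle dependencies:

```cpp
// Minimal sketch of the env-var parsing used by InitWhiteListFormEnv: split a
// comma-separated list (e.g. PADDLE_INF_NAN_SKIP_VAR="x,y") into a set.
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>

std::unordered_set<std::string> ParseSkipList(const char* env_name) {
  std::unordered_set<std::string> result;
  const char* value = std::getenv(env_name);
  if (value) {  // same null check the patch simplifies from `!= NULL`
    std::stringstream ss(value);
    std::string item;
    while (std::getline(ss, item, ',')) {
      result.insert(item);
    }
  }
  return result;
}

int main() {
  for (const auto& var : ParseSkipList("PADDLE_INF_NAN_SKIP_VAR")) {
    std::cout << "skip var: " << var << "\n";
  }
  return 0;
}
```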
template <typename T>
static void PrintNanInf(const T* value,
const size_t numel,
int print_num,
const std::string& op_type,
const std::string& var_name,
bool abort = true) {
T min_value = std::numeric_limits<T>::max();
T max_value = std::numeric_limits<T>::min();
size_t nan_count, inf_count, num_count;
nan_count = inf_count = num_count = 0;
// CPU print num value
for (size_t i = 0; i < numel; ++i) {
size_t count = 0;
if (std::isnan(value[i])) {
count = nan_count++;
} else if (std::isinf(value[i])) {
count = inf_count++;
} else {
count = num_count++;
min_value = std::min(min_value, value[i]);
max_value = std::max(max_value, value[i]);
}
if (count < static_cast<size_t>(print_num)) {
printf("numel:%zu index:%zu value:%f\n",
numel,
i,
static_cast<float>(value[i]));
}
}
printf(
"In cpu, there has %zu,%zu,%zu nan,inf,num. "
"And in num, min_value is %f, max_value is %f\n",
nan_count,
inf_count,
num_count,
static_cast<double>(min_value),
static_cast<double>(max_value));
if (abort) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).",
var_name,
op_type));
}
}
template <
typename T,
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
!std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
static void CheckNanInfCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str) {
using MT = typename phi::dtype::template MPTypeTrait<T>::Type;
#ifdef _OPENMP
// Use at most 4 threads to collect the NaN and Inf information.
int num_threads = std::max(omp_get_num_threads(), 1);
num_threads = std::min(num_threads, 4);
#else
int num_threads = 1;
#endif
// openmp 4.0, reduction with fp16
#if defined _OPENMP && _OPENMP >= 201307
// For more detail, see page 180 of
// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex < \
float > : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex < \
double > : omp_out += omp_in)
std::vector<int64_t> thread_num_nan(num_threads, 0);
std::vector<int64_t> thread_num_inf(num_threads, 0);
std::vector<MT> thread_min_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_max_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_mean_value(num_threads, static_cast<MT>(0));
#ifdef _OPENMP
#pragma omp parallel num_threads(num_threads)
#endif
template <typename T>
static void CheckNanInf(const T* value,
const size_t numel,
int print_num,
const std::string& op_type,
const std::string& var_name) {
T sum = static_cast<T>(0.0);
#if defined _OPENMP && _OPENMP >= 201307
#pragma omp parallel for simd reduction(+ : sum)
#elif defined _OPENMP
#pragma omp parallel for reduction(+ : sum)
{
#ifdef _OPENMP
int64_t tid = omp_get_thread_num();
int64_t chunk_size = (numel + num_threads - 1) / num_threads;
int64_t begin = tid * chunk_size;
int64_t end = chunk_size + begin > numel ? numel : chunk_size + begin;
#else
int64_t tid = 0;
int64_t begin = 0;
int64_t end = numel;
#endif
for (size_t i = 0; i < numel; ++i) {
sum += (value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
#if defined _OPENMP && _OPENMP >= 201307
// openmp4.0 not need to specialization fp16
#elif defined _OPENMP
template <>
void CheckNanInf<paddle::platform::float16>(
const paddle::platform::float16* value,
const size_t numel,
int print_num,
const std::string& op_type,
const std::string& var_name) {
float sum = 0.0f;
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += static_cast<float>(value[i] - value[i]);
}
for (int64_t i = begin; i < end; ++i) {
MT value = static_cast<MT>(value_ptr[i]);
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
thread_min_value[tid] = std::min(thread_min_value[tid], value);
thread_max_value[tid] = std::max(thread_max_value[tid], value);
thread_mean_value[tid] += value / static_cast<MT>(numel);
template <>
void CheckNanInf<paddle::platform::bfloat16>(
const paddle::platform::bfloat16* value,
const size_t numel,
int print_num,
const std::string& op_type,
const std::string& var_name) {
float sum = 0.0f;
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += static_cast<float>(value[i] - value[i]);
if (std::isnan(value)) {
thread_num_nan[tid] += 1;
} else if (std::isinf(value)) {
thread_num_inf[tid] += 1;
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
template <>
void CheckNanInf<paddle::platform::complex<float>>(
const paddle::platform::complex<float>* value,
const size_t numel,
int print_num,
const std::string& op_type,
const std::string& var_name) {
float real_sum = 0.0f;
#pragma omp parallel for reduction(+ : real_sum)
for (size_t i = 0; i < numel; ++i) {
real_sum += (value[i].real - value[i].real);
}
float imag_sum = 0.0f;
#pragma omp parallel for reduction(+ : imag_sum)
for (size_t i = 0; i < numel; ++i) {
imag_sum += (value[i].imag - value[i].imag);
int64_t num_nan = 0;
int64_t num_inf = 0;
MT min_value = thread_min_value[0];
MT max_value = thread_max_value[0];
MT mean_value = static_cast<MT>(0);
for (int i = 0; i < num_threads; ++i) {
num_nan += thread_num_nan[i];
num_inf += thread_num_inf[i];
min_value = std::min(thread_min_value[i], min_value);
max_value = std::max(thread_max_value[i], max_value);
mean_value += thread_mean_value[i];
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).",
var_name,
op_type));
}
PrintForDifferentLevel<T, MT>(cpu_hint_str.c_str(),
numel,
num_nan,
num_inf,
max_value,
min_value,
mean_value,
FLAGS_check_nan_inf_level);
}
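The rewritten CPU path above partitions the tensor into one contiguous chunk per OpenMP thread, lets each thread accumulate its own NaN/Inf counts plus min/max/mean, and then reduces the per-thread buffers serially before calling PrintForDifferentLevel. A compilable sketch of that chunk-and-reduce pattern, with plain float standing in for T/MT and without the Paddle-specific reduction declarations:

```cpp
// Sketch of the per-thread chunking and serial reduction used by the new
// CheckNanInfCpuImpl; plain float stands in for T/MT and OpenMP is optional.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

void CountNanInf(const float* data, int64_t numel) {
#ifdef _OPENMP
  int num_threads = std::min(std::max(omp_get_max_threads(), 1), 4);
#else
  int num_threads = 1;
#endif
  std::vector<int64_t> num_nan(num_threads, 0), num_inf(num_threads, 0);
  std::vector<float> min_v(num_threads, data[0]), max_v(num_threads, data[0]);
  std::vector<float> mean_v(num_threads, 0.0f);

#ifdef _OPENMP
#pragma omp parallel num_threads(num_threads)
#endif
  {
#ifdef _OPENMP
    int tid = omp_get_thread_num();
#else
    int tid = 0;
#endif
    int64_t chunk = (numel + num_threads - 1) / num_threads;
    int64_t begin = tid * chunk;
    int64_t end = std::min(begin + chunk, numel);
    for (int64_t i = begin; i < end; ++i) {
      float v = data[i];
      min_v[tid] = std::min(min_v[tid], v);
      max_v[tid] = std::max(max_v[tid], v);
      mean_v[tid] += v / static_cast<float>(numel);
      if (std::isnan(v)) {
        num_nan[tid] += 1;
      } else if (std::isinf(v)) {
        num_inf[tid] += 1;
      }
    }
  }

  // Serial reduction of the per-thread partial results.
  int64_t total_nan = 0, total_inf = 0;
  float min_value = min_v[0], max_value = max_v[0], mean_value = 0.0f;
  for (int t = 0; t < num_threads; ++t) {
    total_nan += num_nan[t];
    total_inf += num_inf[t];
    min_value = std::min(min_value, min_v[t]);
    max_value = std::max(max_value, max_v[t]);
    mean_value += mean_v[t];
  }
  std::printf("num_nan=%lld, num_inf=%lld, max=%e, min=%e, mean=%e\n",
              static_cast<long long>(total_nan),
              static_cast<long long>(total_inf),
              max_value, min_value, mean_value);
}

int main() {
  std::vector<float> x = {1.0f, -2.0f, NAN, INFINITY, 3.5f};
  CountNanInf(x.data(), static_cast<int64_t>(x.size()));
  return 0;
}
```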
template <>
void CheckNanInf<paddle::platform::complex<double>>(
const paddle::platform::complex<double>* value,
const size_t numel,
int print_num,
const std::string& op_type,
const std::string& var_name) {
double real_sum = 0.0;
#pragma omp parallel for reduction(+ : real_sum)
for (size_t i = 0; i < numel; ++i) {
real_sum += (value[i].real - value[i].real);
}
template <
typename T,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
void CheckNanInfCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str) {
using RealType = typename T::value_type;
RealType real_sum = 0.0f, imag_sum = 0.0f;
double imag_sum = 0.0;
#pragma omp parallel for reduction(+ : imag_sum)
for (size_t i = 0; i < numel; ++i) {
imag_sum += (value[i].imag - value[i].imag);
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum)
#endif
for (int64_t i = 0; i < numel; ++i) {
T value = value_ptr[i];
real_sum += (value.real - value.real);
imag_sum += (value.imag - value.imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
......@@ -307,14 +236,10 @@ template <>
// Hot fix for a compile failure on gcc 4.8.
// Detailed NaN/Inf info should also be printed here later.
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).",
var_name,
op_type));
"There are NAN or INF in %s.", cpu_hint_str));
}
}
#endif
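The complex specialization keeps the older detection trick: for any finite v, v - v == 0, whereas NaN - NaN and Inf - Inf are NaN, so summing value.real - value.real (and the imaginary counterpart) over the tensor yields a NaN sum exactly when the tensor contains a NaN or an Inf. A tiny demonstration with std::complex (Paddle's complex exposes .real/.imag as data members rather than methods):

```cpp
// Demonstrates the (v - v) trick the complex-number path relies on: the sum of
// differences is NaN iff the data contains a NaN or an Inf.
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

bool HasNanInf(const std::vector<std::complex<float>>& values) {
  float real_sum = 0.0f, imag_sum = 0.0f;
  for (const auto& v : values) {
    real_sum += v.real() - v.real();  // 0 for finite values, NaN for NaN/Inf
    imag_sum += v.imag() - v.imag();
  }
  return std::isnan(real_sum) || std::isinf(real_sum) ||
         std::isnan(imag_sum) || std::isinf(imag_sum);
}

int main() {
  std::printf("%d\n", HasNanInf({{1.0f, 2.0f}, {3.0f, 4.0f}}));      // prints 0
  std::printf("%d\n", HasNanInf({{1.0f, 2.0f}, {INFINITY, 0.0f}}));  // prints 1
  return 0;
}
```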
template <>
template <typename T>
void TensorCheckerVisitor<phi::CPUContext>::apply(
......@@ -323,10 +248,9 @@ void TensorCheckerVisitor<phi::CPUContext>::apply(
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
// use env strategy control in future, -1=print_all.
int print_num = 3;
CheckNanInf(
tensor_.data<T>(), tensor_.numel(), print_num, op_type_, var_name_);
std::string cpu_hint_str =
GetCpuHintString<T>(op_type, var_name, tensor.place());
CheckNanInfCpuImpl(tensor.data<T>(), tensor.numel(), cpu_hint_str);
}
template <>
......@@ -371,8 +295,8 @@ void CheckVarHasNanOrInf(const std::string& op_type,
tensor_check<phi::GPUContext>(op_type, var_name, *tensor, place);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"phi::DenseTensor[%s] use gpu place. PaddlePaddle must compile with "
"GPU.",
"phi::DenseTensor[%s] use gpu place. PaddlePaddle must compile "
"with GPU.",
var_name));
#endif
return;
......@@ -406,8 +330,8 @@ void CheckVarHasNanOrInf(const std::string& op_type,
var_name));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"phi::DenseTensor[%s] use xpu place. PaddlePaddle must compile with "
"XPU.",
"phi::DenseTensor[%s] use xpu place. PaddlePaddle must compile "
"with XPU.",
var_name));
#endif
return;
......@@ -440,8 +364,8 @@ void CheckVarHasNanOrInf(const std::string& op_type,
var_name));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"phi::DenseTensor[%s] use npu place. PaddlePaddle must compile with "
"NPU.",
"phi::DenseTensor[%s] use npu place. PaddlePaddle must compile "
"with NPU.",
var_name));
#endif
return;
......
......@@ -138,6 +138,54 @@ __global__ void CheckNanInfKernel(const T* value,
PrintNanInfKernel(value, numel, print_num, debug_info);
}
template <typename T, int ReduceType>
__device__ T BlockReduce(T value) {
__shared__ T shared_mem[1024];
shared_mem[threadIdx.x] = value;
__syncthreads();
for (int stride = blockDim.x >> 1; stride > 0; stride = stride >> 1) {
if (threadIdx.x < stride) {
T value0 = shared_mem[threadIdx.x];
T value1 = shared_mem[threadIdx.x + stride];
T reduce_value;
if (ReduceType == 0) {
// max
reduce_value = value0 > value1 ? value0 : value1;
} else if (ReduceType == 1) {
// min
reduce_value = value0 < value1 ? value0 : value1;
} else if (ReduceType == 2) {
// sum
reduce_value = value0 + value1;
}
shared_mem[threadIdx.x] = reduce_value;
}
if (stride > 16) {
__syncthreads();
}
}
__syncthreads();
return shared_mem[0];
}
__device__ void BlockReduceNumNanInfAndWrite(const int64_t num_nan,
const int64_t num_inf,
int64_t offset,
int64_t* num_nan_ptr,
int64_t* num_inf_ptr) {
int64_t block_num_nan = BlockReduce<int64_t, 2>(num_nan);
int64_t block_num_inf = BlockReduce<int64_t, 2>(num_inf);
if (threadIdx.x == 0) {
num_nan_ptr[offset] = block_num_nan;
num_inf_ptr[offset] = block_num_inf;
}
}
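BlockReduce above keeps one value per thread in shared memory and halves the active stride each iteration; ReduceType selects max (0), min (1) or sum (2). BlockReduceNumNanInfAndWrite then runs the sum variant twice to obtain per-block NaN and Inf counts, which thread 0 writes out. A host-side sketch of the same tree reduction (a plain vector stands in for shared memory; the `stride > 16` early exit from __syncthreads relies on warp-synchronous execution on the device and is omitted here):

```cpp
// Host-side sketch of the tree reduction implemented by BlockReduce on the GPU.
// reduce_type: 0 = max, 1 = min, 2 = sum, mirroring the kernel's ReduceType.
// Like the kernel, it assumes the number of "threads" is a power of two.
#include <cstdio>
#include <vector>

template <typename T>
T TreeReduce(std::vector<T> values, int reduce_type) {
  for (size_t stride = values.size() >> 1; stride > 0; stride >>= 1) {
    for (size_t tid = 0; tid < stride; ++tid) {  // all "threads" run in parallel on GPU
      T v0 = values[tid];
      T v1 = values[tid + stride];
      if (reduce_type == 0) {
        values[tid] = v0 > v1 ? v0 : v1;  // max
      } else if (reduce_type == 1) {
        values[tid] = v0 < v1 ? v0 : v1;  // min
      } else {
        values[tid] = v0 + v1;            // sum
      }
    }
    // __syncthreads() would go here on the device.
  }
  return values[0];
}

int main() {
  std::vector<long long> per_thread_nan = {0, 2, 1, 0, 0, 0, 3, 0};  // 8 "threads"
  std::printf("block num_nan = %lld\n",
              TreeReduce(per_thread_nan, /*reduce_type=*/2));  // prints 6
  return 0;
}
```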
template <
typename T,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
......@@ -183,15 +231,16 @@ __device__ void BlockReduceMaxMinAndWrite(const T max_value,
template <typename T, typename MT>
__global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
const int64_t numel,
int* found_nan_inf_ptr,
int64_t* block_num_nan_ptr,
int64_t* block_num_inf_ptr,
MT* tensor_block_max_ptr,
MT* tensor_block_min_ptr,
MT* tensor_block_mean_ptr) {
bool has_nan = false;
bool has_inf = false;
int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
int64_t num_nan = 0;
int64_t num_inf = 0;
MT max_value = static_cast<MT>(i < numel ? value_ptr[i] : value_ptr[0]);
MT min_value = static_cast<MT>(i < numel ? value_ptr[i] : value_ptr[0]);
MT mean_value = static_cast<MT>(0);
......@@ -203,25 +252,14 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
mean_value += value / static_cast<MT>(numel);
if (isnan(value)) {
has_nan = true;
num_nan += 1;
} else if (isinf(value)) {
num_inf += 1;
}
if (isinf(value)) {
has_inf = true;
}
if (has_nan || has_inf) {
if (!tensor_block_max_ptr && !tensor_block_min_ptr &&
!tensor_block_mean_ptr) {
break;
}
}
}
if (has_nan) {
found_nan_inf_ptr[0] = 1;
}
if (has_inf) {
found_nan_inf_ptr[1] = 1;
}
BlockReduceNumNanInfAndWrite(
num_nan, num_inf, blockIdx.x, block_num_nan_ptr, block_num_inf_ptr);
BlockReduceMaxMinAndWrite<MT>(max_value,
min_value,
......@@ -232,32 +270,9 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
tensor_block_mean_ptr);
}
template <typename T,
typename MT,
std::enable_if_t<std::is_same<T, float>::value, bool> = true>
__device__ bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
if (check_nan_inf_level >= 3) {
return true;
} else if (check_nan_inf_level >= 2) {
MT fp16_max =
static_cast<MT>(std::numeric_limits<phi::dtype::float16>::max());
return max_value > fp16_max || min_value < -fp16_max;
}
return false;
}
template <typename T,
typename MT,
std::enable_if_t<!std::is_same<T, float>::value, bool> = true>
__device__ bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
if (check_nan_inf_level >= 3) {
return true;
}
return false;
}
template <typename T, typename MT>
__global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
__global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
const int64_t* block_num_inf_ptr,
const MT* tensor_block_max_ptr,
const MT* tensor_block_min_ptr,
const MT* tensor_block_mean_ptr,
......@@ -266,8 +281,14 @@ __global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
int64_t numel_max_min,
int check_nan_inf_level) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
int has_nan = found_nan_inf_ptr[0];
int has_inf = found_nan_inf_ptr[1];
int64_t num_nan = 0;
int64_t num_inf = 0;
// numel_max_min <= 128
for (int64_t i = 0; i < numel_max_min; ++i) {
num_nan += block_num_nan_ptr[i];
num_inf += block_num_inf_ptr[i];
}
MT max_value = static_cast<MT>(0);
MT min_value = static_cast<MT>(0);
......@@ -289,67 +310,31 @@ __global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
}
}
if (has_nan || has_inf) {
if (check_nan_inf_level == 0) {
PADDLE_ENFORCE(false,
"===[PRECISION] [ERROR] in %s, numel=%ld, find_nan=%d, "
"find_inf=%d, "
"max=%e, min=%e, mean=%e===\n",
debug_info,
numel,
has_nan,
has_inf,
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
} else if (check_nan_inf_level >= 1) {
printf(
"===[PRECISION] [ERROR] in %s, numel=%ld, find_nan=%d, "
"find_inf=%d, "
"max=%e, min=%e, mean=%e===\n",
debug_info,
numel,
has_nan,
has_inf,
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
}
} else if (NeedPrint<T, MT>(max_value, min_value, check_nan_inf_level)) {
printf("[PRECISION] in %s, numel=%ld, max=%e, min=%e, mean=%e\n",
debug_info,
PrintForDifferentLevel<T, MT>(debug_info,
numel,
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
}
num_nan,
num_inf,
max_value,
min_value,
mean_value,
check_nan_inf_level);
}
}
template <>
template <typename T>
void TensorCheckerVisitor<phi::GPUContext>::apply(
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(tensor_.place()));
int dev_id = tensor_.place().device;
static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
const std::string& op_type,
const std::string& var_name,
int dev_id) {
PADDLE_ENFORCE_EQ(
(dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()),
true,
platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d",
multi_op_var2gpu_str_mutex().size()));
std::string dtype_str = DataTypeToString(DataTypeTrait<T>::DataType());
if (dtype_str == "::paddle::platform::float16") {
dtype_str = "float16";
}
std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ +
"] [dtype=" + dtype_str + "]";
char* gpu_str_ptr = NULL;
std::string op_var =
GetCpuHintString<T>(op_type, var_name, ctx.GetPlace(), dev_id);
char* gpu_str_ptr = nullptr;
{
auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id);
......@@ -358,9 +343,9 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
std::lock_guard<std::mutex> guard(op_var2gpu_str_mutex);
if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) { // insert
auto gpu_str_tensor = paddle::memory::Alloc(
dev_ctx->GetPlace(),
ctx.GetPlace(),
op_var.length() + 1,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx->stream())));
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
gpu_str_ptr = reinterpret_cast<char*>(gpu_str_tensor->ptr());
op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor));
......@@ -378,13 +363,13 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
iter->first.c_str(),
op_var.length() + 1,
hipMemcpyHostToDevice,
dev_ctx->stream()));
ctx.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(gpu_str_ptr,
iter->first.c_str(),
op_var.length() + 1,
cudaMemcpyHostToDevice,
dev_ctx->stream()));
ctx.stream()));
#endif
} else { // get
auto iter = op_var2gpu_str.find(op_var);
......@@ -397,6 +382,22 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
gpu_str_ptr = reinterpret_cast<char*>(iter->second->ptr());
}
}
return gpu_str_ptr;
}
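GetGpuHintStringPtr caches one device-side copy of each "[device=..., op=..., tensor=..., dtype=...]" string per GPU, guarded by a per-device mutex, so the kernels can print it. A simplified host-only sketch of that cache, with heap buffers standing in for device allocations and a single mutex standing in for the per-device mutex array:

```cpp
// Simplified sketch of the debug-string cache behind GetGpuHintStringPtr: one
// buffer per distinct hint string, created once under a mutex, reused afterwards.
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>

char* GetCachedHintBuffer(const std::string& op_var) {
  static std::mutex cache_mutex;  // per-device in Paddle
  static std::map<std::string, std::unique_ptr<char[]>> cache;

  std::lock_guard<std::mutex> guard(cache_mutex);
  auto iter = cache.find(op_var);
  if (iter == cache.end()) {  // insert: first use of this op/tensor/dtype string
    auto buffer = std::unique_ptr<char[]>(new char[op_var.length() + 1]);
    std::memcpy(buffer.get(), op_var.c_str(), op_var.length() + 1);
    iter = cache.emplace(op_var, std::move(buffer)).first;
  }
  return iter->second.get();  // get: reuse the cached buffer
}

int main() {
  char* p1 = GetCachedHintBuffer("[device=gpu:0, op=log, tensor=x, dtype=fp32]");
  char* p2 = GetCachedHintBuffer("[device=gpu:0, op=log, tensor=x, dtype=fp32]");
  std::cout << (p1 == p2 ? "cached" : "recreated") << ": " << p1 << "\n";
  return 0;
}
```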
template <>
template <typename T>
void TensorCheckerVisitor<phi::GPUContext>::apply(
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(tensor.place()));
int dev_id = tensor.place().device;
char* gpu_str_ptr =
GetGpuHintStringPtr<T>(*dev_ctx, op_type, var_name, dev_id);
#ifdef __HIPCC__
// HIP will throw GPU memory access fault if threads > 256
......@@ -406,7 +407,7 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
#endif
size_t blocks =
std::min(static_cast<size_t>(128),
static_cast<size_t>((tensor_.numel() + threads - 1) / threads));
static_cast<size_t>((tensor.numel() + threads - 1) / threads));
#ifdef __HIPCC__
int print_num = 3;
......@@ -415,44 +416,46 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
dim3(threads),
0,
dev_ctx->stream(),
tensor_.data<T>(),
tensor_.numel(),
tensor.data<T>(),
tensor.numel(),
print_num,
gpu_str_ptr);
#else
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
phi::DenseTensor found_nan_inf;
found_nan_inf.Resize({2});
int* found_nan_inf_ptr = found_nan_inf.mutable_data<int>(tensor_.place());
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
found_nan_inf_ptr, 0, 2 * sizeof(int), dev_ctx->stream()));
int64_t numel_max_min = blocks;
phi::DenseTensor block_num_nan_inf;
block_num_nan_inf.Resize({static_cast<int64_t>(2 * numel_max_min)});
int64_t* block_num_nan_ptr =
block_num_nan_inf.mutable_data<int64_t>(tensor.place());
int64_t* block_num_inf_ptr = block_num_nan_ptr + numel_max_min;
phi::DenseTensor tensor_block_max_min;
tensor_block_max_min.Resize({static_cast<int64_t>(3 * numel_max_min)});
MT* tensor_block_max_ptr =
tensor_block_max_min.mutable_data<MT>(tensor_.place());
tensor_block_max_min.mutable_data<MT>(tensor.place());
MT* tensor_block_min_ptr = tensor_block_max_ptr + numel_max_min;
MT* tensor_block_mean_ptr = tensor_block_max_ptr + 2 * numel_max_min;
FindNanInfAndBlockMaxMin<T, MT>
<<<blocks, threads, 0, dev_ctx->stream()>>>(tensor_.data<T>(),
tensor_.numel(),
found_nan_inf_ptr,
<<<blocks, threads, 0, dev_ctx->stream()>>>(tensor.data<T>(),
tensor.numel(),
block_num_nan_ptr,
block_num_inf_ptr,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr);
int check_nan_inf_level = FLAGS_check_nan_inf_level;
FindGlobalMaxMinAndPrint<T, MT>
<<<1, 1, 0, dev_ctx->stream()>>>(found_nan_inf_ptr,
<<<1, 1, 0, dev_ctx->stream()>>>(block_num_nan_ptr,
block_num_inf_ptr,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr,
gpu_str_ptr,
tensor_.numel(),
tensor.numel(),
numel_max_min,
check_nan_inf_level);
#endif
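The launch configuration caps blocks at 128, so each statistic produces at most 128 per-block partials (numel_max_min) for the single-threaded FindGlobalMaxMinAndPrint kernel to scan. A small sketch of that arithmetic; the exact per-platform threads value is not fully visible in this hunk, and 256 is the HIP bound mentioned in the comment above:

```cpp
// Sketch of the launch-size arithmetic: blocks is capped at 128 so that the
// one-thread FindGlobalMaxMinAndPrint kernel aggregates at most 128 partials.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  int64_t numel = 1 << 20;  // example tensor size
  size_t threads = 256;     // HIP upper bound noted above; assumption here
  size_t blocks =
      std::min(static_cast<size_t>(128),
               (static_cast<size_t>(numel) + threads - 1) / threads);
  size_t numel_max_min = blocks;  // one nan/inf count and max/min/mean per block
  std::printf("blocks=%zu, numel_max_min=%zu\n", blocks, numel_max_min);
  return 0;
}
```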
......
......@@ -24,21 +24,114 @@ namespace paddle {
namespace framework {
namespace details {
template <typename T,
typename MT,
std::enable_if_t<std::is_same<T, float>::value, bool> = true>
HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
if (check_nan_inf_level >= 3) {
return true;
} else if (check_nan_inf_level >= 2) {
MT fp16_max =
static_cast<MT>(std::numeric_limits<phi::dtype::float16>::max());
return max_value > fp16_max || min_value < -fp16_max;
}
return false;
}
template <typename T,
typename MT,
std::enable_if_t<!std::is_same<T, float>::value, bool> = true>
HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
if (check_nan_inf_level >= 3) {
return true;
}
return false;
}
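Together with PrintForDifferentLevel below, the two NeedPrint overloads encode the FLAGS_check_nan_inf_level policy: level 0 prints and raises an error when NaN/Inf is found, level 1 only prints the error, level 2 additionally reports fp32 tensors whose max/min fall outside the float16 range (a cheap fp16-overflow screen), and level 3 reports statistics for every checked tensor. A condensed, non-Paddle restatement of that decision logic:

```cpp
// Condensed restatement of the check_nan_inf_level policy implemented by
// NeedPrint + PrintForDifferentLevel above (not Paddle code, just the logic).
#include <cstdint>
#include <type_traits>

enum class Action { kNone, kPrintStats, kPrintError, kPrintErrorAndThrow };

template <typename T>
Action DecideAction(int64_t num_nan, int64_t num_inf, float max_value,
                    float min_value, int check_nan_inf_level) {
  if (num_nan > 0 || num_inf > 0) {
    // Level 0: print and raise an error; level >= 1: print only.
    return check_nan_inf_level == 0 ? Action::kPrintErrorAndThrow
                                    : Action::kPrintError;
  }
  if (check_nan_inf_level >= 3) {
    return Action::kPrintStats;  // print statistics for every tensor
  }
  if (check_nan_inf_level >= 2 && std::is_same<T, float>::value) {
    // Only fp32 tensors are screened against the fp16 representable range.
    const float fp16_max = 65504.0f;  // numeric_limits<float16>::max()
    if (max_value > fp16_max || min_value < -fp16_max) {
      return Action::kPrintStats;
    }
  }
  return Action::kNone;
}

int main() {
  // An fp32 tensor with max 7e4 and no NaN/Inf is reported at level >= 2.
  return DecideAction<float>(0, 0, 7.0e4f, -1.0f, 2) == Action::kPrintStats ? 0 : 1;
}
```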
template <typename T, typename MT>
HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
int64_t numel,
int64_t num_nan,
int64_t num_inf,
MT max_value,
MT min_value,
MT mean_value,
int check_nan_inf_level) {
if (num_nan > 0 || num_inf > 0) {
printf(
"[PRECISION] [ERROR] in %s, numel=%lld, num_nan=%lld, "
"num_inf=%lld, max=%e, min=%e, mean=%e\n",
debug_info,
static_cast<long long>(numel), // NOLINT
static_cast<long long>(num_nan), // NOLINT
static_cast<long long>(num_inf), // NOLINT
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
if (check_nan_inf_level == 0) {
#if defined(__NVCC__) || defined(__HIPCC__)
PADDLE_ENFORCE(false,
"There are NAN or INF (num_nan=%ld, num_inf=%lld) in %s.",
static_cast<long long>(num_nan), // NOLINT
static_cast<long long>(num_inf), // NOLINT
debug_info);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are NAN or INF (num_nan=%lld, num_inf=%lld) in %s.",
static_cast<long long>(num_nan), // NOLINT
static_cast<long long>(num_inf), // NOLINT
debug_info));
#endif
}
} else if (NeedPrint<T, MT>(max_value, min_value, check_nan_inf_level)) {
printf("[PRECISION] in %s, numel=%lld, max=%e, min=%e, mean=%e\n",
debug_info,
static_cast<long long>(numel), // NOLINT
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
}
}
template <typename T>
inline std::string GetCpuHintString(const std::string& op_type,
const std::string& var_name,
const phi::Place& place,
int device_id = -1) {
std::string dtype_str = DataTypeToString(DataTypeTrait<T>::DataType());
if (dtype_str == "float") {
dtype_str = "fp32";
} else if (dtype_str == "double") {
dtype_str = "fp64";
} else if (dtype_str == "::paddle::platform::float16") {
dtype_str = "fp16";
} else if (dtype_str == "::paddle::platform::bfloat16") {
dtype_str = "bf16";
}
std::stringstream ss;
if (platform::is_gpu_place(place)) {
ss << "[device=gpu:" << device_id << ", ";
} else {
ss << "[device=cpu, ";
}
ss << "op=" << op_type << ", tensor=" << var_name << ", dtype=" << dtype_str
<< "]";
return ss.str();
}
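GetCpuHintString builds the human-readable prefix used both in the CPU exception message and in the GPU debug string, e.g. "[device=gpu:0, op=log, tensor=x, dtype=fp32]" or "[device=cpu, op=log, tensor=x, dtype=fp32]". A standalone sketch producing the same format, with the dtype name passed in directly instead of going through DataTypeToString/DataTypeTrait:

```cpp
// Standalone sketch of the hint-string format produced by GetCpuHintString.
#include <iostream>
#include <sstream>
#include <string>

std::string MakeHintString(const std::string& op_type, const std::string& var_name,
                           const std::string& dtype_str, bool is_gpu,
                           int device_id = -1) {
  std::stringstream ss;
  if (is_gpu) {
    ss << "[device=gpu:" << device_id << ", ";
  } else {
    ss << "[device=cpu, ";
  }
  ss << "op=" << op_type << ", tensor=" << var_name << ", dtype=" << dtype_str
     << "]";
  return ss.str();
}

int main() {
  std::cout << MakeHintString("log", "x", "fp32", /*is_gpu=*/true, 0) << "\n";
  // Prints: [device=gpu:0, op=log, tensor=x, dtype=fp32]
  std::cout << MakeHintString("log", "x", "fp32", /*is_gpu=*/false) << "\n";
  // Prints: [device=cpu, op=log, tensor=x, dtype=fp32]
  return 0;
}
```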
template <typename DeviceContext>
struct TensorCheckerVisitor {
TensorCheckerVisitor(const std::string& op_type,
const std::string& var_name,
const phi::DenseTensor& tensor,
const platform::Place& place)
: op_type_(op_type),
var_name_(var_name),
tensor_(tensor),
place_(place) {}
TensorCheckerVisitor(const std::string& o,
const std::string& v,
const phi::DenseTensor& t,
const platform::Place& p)
: op_type(o), var_name(v), tensor(t), place(p) {}
template <typename T>
void apply(
typename std::enable_if<std::is_integral<T>::value>::type* = 0) const {
VLOG(10) << var_name_ << " need not to check, it's type is not float point";
VLOG(10) << var_name << " does not need to be checked, its type is not floating point";
}
template <typename T>
......@@ -49,10 +142,10 @@ struct TensorCheckerVisitor {
std::is_same<T, ::paddle::platform::complex<double>>::value>::type* =
0) const;
std::string op_type_;
std::string var_name_;
const phi::DenseTensor& tensor_;
const platform::Place& place_;
std::string op_type;
std::string var_name;
const phi::DenseTensor& tensor;
const platform::Place& place;
};
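TensorCheckerVisitor dispatches on the element type through enable_if: integral tensors are skipped with only a VLOG, while floating-point and complex tensors reach the device-specific apply specialization. A minimal sketch of that SFINAE dispatch pattern, using standard types only:

```cpp
// Minimal sketch of the enable_if dispatch used by TensorCheckerVisitor:
// integral element types are skipped, floating-point types are checked.
#include <iostream>
#include <type_traits>

struct CheckerVisitor {
  template <typename T>
  void apply(
      typename std::enable_if<std::is_integral<T>::value>::type* = 0) const {
    std::cout << "skip: not a floating-point tensor\n";
  }

  template <typename T>
  void apply(
      typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) const {
    std::cout << "run NaN/Inf check for this dtype\n";
  }
};

int main() {
  CheckerVisitor visitor;
  visitor.apply<int>();    // skipped
  visitor.apply<float>();  // checked
  return 0;
}
```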
template <typename DeviceContext>
......
......@@ -17,9 +17,9 @@ import subprocess
import sys
import unittest
import paddle
import numpy as np
paddle.enable_static()
import paddle
class TestNanInf(unittest.TestCase):
......@@ -47,12 +47,7 @@ class TestNanInf(unittest.TestCase):
print(err)
# in python3, type(out+err) is 'bytes', need use encode
if paddle.fluid.core.is_compiled_with_cuda():
assert (out + err).find('find_nan=1, find_inf=1'.encode()) != -1
else:
assert (out + err).find(
'There are `nan` or `inf` in tensor'.encode()
) != -1
assert (out + err).find('There are NAN or INF'.encode()) != -1
def test_nan_inf_in_static_mode(self):
self._python_interp += " check_nan_inf_base.py"
......@@ -75,5 +70,97 @@ class TestNanInfEnv(TestNanInf):
)
class TestNanInfCheckResult(unittest.TestCase):
def generate_inputs(self, shape, dtype="float32"):
data = np.random.random(size=shape).astype(dtype)
# [-10, 10)
x = (data * 20 - 10) * np.random.randint(
low=0, high=2, size=shape
).astype(dtype)
y = np.random.randint(low=0, high=2, size=shape).astype(dtype)
return x, y
def get_reference_num_nan_inf(self, x):
out = np.log(x)
num_nan = np.sum(np.isnan(out))
num_inf = np.sum(np.isinf(out))
print("[reference] num_nan={}, num_inf={}".format(num_nan, num_inf))
return num_nan, num_inf
def get_num_nan_inf(self, x_np, use_cuda=True, add_assert=False):
num_nan = 0
num_inf = 0
try:
if use_cuda:
paddle.device.set_device("gpu:0")
else:
paddle.device.set_device("cpu")
x = paddle.to_tensor(x_np)
out = paddle.log(x)
sys.stdout.flush()
if add_assert:
assert False
except Exception as e:
# The log printed inside the CUDA kernel cannot be captured here.
err_str_list = (
str(e)
.replace("(", " ")
.replace(")", " ")
.replace(",", " ")
.split(" ")
)
for err_str in err_str_list:
if "num_nan" in err_str:
num_nan = int(err_str.split("=")[1])
elif "num_inf" in err_str:
num_inf = int(err_str.split("=")[1])
print("[paddle] num_nan={}, num_inf={}".format(num_nan, num_inf))
return num_nan, num_inf
def test_num_nan_inf(self):
def _check_num_nan_inf(use_cuda):
shape = [32, 32]
x_np, _ = self.generate_inputs(shape)
num_nan_np, num_inf_np = self.get_reference_num_nan_inf(x_np)
add_assert = (num_nan_np + num_inf_np) > 0
num_nan, num_inf = self.get_num_nan_inf(x_np, use_cuda, add_assert)
if not use_cuda:
assert num_nan == num_nan_np and num_inf == num_inf_np
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 0}
)
_check_num_nan_inf(use_cuda=False)
if paddle.fluid.core.is_compiled_with_cuda():
_check_num_nan_inf(use_cuda=True)
def check_nan_inf_level(self, use_cuda, dtype):
shape = [8, 8]
x_np, y_np = self.generate_inputs(shape, dtype)
if use_cuda:
paddle.device.set_device("gpu:0")
else:
paddle.device.set_device("cpu")
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
out = paddle.log(x * 1e6) / y
def test_check_nan_inf_level_float32(self):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 2}
)
self.check_nan_inf_level(use_cuda=False, dtype="float32")
if paddle.fluid.core.is_compiled_with_cuda():
self.check_nan_inf_level(use_cuda=True, dtype="float32")
def test_check_nan_inf_level_float16(self):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3}
)
if paddle.fluid.core.is_compiled_with_cuda():
self.check_nan_inf_level(use_cuda=True, dtype="float16")
if __name__ == '__main__':
unittest.main()