未验证 提交 ad39043f 编写于 作者: Y Yiqun Liu 提交者: GitHub

Improve the tool for checking nan and inf, and support to compute the max, min...

Improve the tool for checking nan and inf, and support to compute the max, min and mean of output tensor. (#47095)

* Improve the tool for checking nan and inf, and support to compute the max, min and mean of output tensor.

* Add a FLAGS to control whether abort when meets inf/nan and polish codes.

* Fix unittest.

* Change the computing of mean.
上级 99f60188
......@@ -12,15 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
DECLARE_bool(abort_on_nan_inf);
DECLARE_bool(check_tensor_max_min);
namespace paddle {
namespace framework {
......@@ -133,6 +139,171 @@ __global__ void CheckNanInfKernel(const T* value,
PrintNanInfKernel(value, numel, print_num, debug_info);
}
template <
typename T,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
__device__ void BlockReduceMaxMinAndWrite(const T max_value,
const T min_value,
const T mean_value,
int64_t offset,
T* max_ptr,
T* min_ptr,
T* mean_ptr) {
// TODO(Xreki): support complex
}
template <
typename T,
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
!std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
__device__ void BlockReduceMaxMinAndWrite(const T max_value,
const T min_value,
const T mean_value,
int64_t offset,
T* max_ptr,
T* min_ptr,
T* mean_ptr) {
if (max_ptr && min_ptr && mean_ptr) {
__syncthreads();
T block_max_value = phi::funcs::blockReduceMax<T>(max_value, FINAL_MASK);
T block_min_value = phi::funcs::blockReduceMin<T>(min_value, FINAL_MASK);
T block_mean_value = phi::funcs::blockReduceSum<T>(mean_value, FINAL_MASK);
if (threadIdx.x == 0) {
max_ptr[offset] = block_max_value;
min_ptr[offset] = block_min_value;
mean_ptr[offset] = block_mean_value;
}
}
}
template <typename T, typename MT>
__global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
const int64_t numel,
int* found_nan_inf_ptr,
MT* tensor_block_max_ptr,
MT* tensor_block_min_ptr,
MT* tensor_block_mean_ptr) {
bool has_nan = false;
bool has_inf = false;
int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
MT max_value = static_cast<MT>(i < numel ? value_ptr[i] : value_ptr[0]);
MT min_value = static_cast<MT>(i < numel ? value_ptr[i] : value_ptr[0]);
MT mean_value = static_cast<MT>(0);
for (; i < numel; i += blockDim.x * gridDim.x) {
MT value = static_cast<MT>(value_ptr[i]);
max_value = value > max_value ? value : max_value;
min_value = value < min_value ? value : min_value;
mean_value += value / static_cast<MT>(numel);
if (isnan(value)) {
has_nan = true;
}
if (isinf(value)) {
has_inf = true;
}
if (has_nan || has_inf) {
if (!tensor_block_max_ptr && !tensor_block_min_ptr &&
!tensor_block_mean_ptr) {
break;
}
}
}
if (has_nan) {
found_nan_inf_ptr[0] = 1;
}
if (has_inf) {
found_nan_inf_ptr[1] = 1;
}
BlockReduceMaxMinAndWrite<MT>(max_value,
min_value,
mean_value,
blockIdx.x,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr);
}
template <typename T>
__global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
const T* tensor_block_max_ptr,
const T* tensor_block_min_ptr,
const T* tensor_block_mean_ptr,
const char* debug_info,
int64_t numel,
int64_t numel_max_min,
bool abort_on_nan_inf,
bool check_tensor_max_min) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
int has_nan = found_nan_inf_ptr[0];
int has_inf = found_nan_inf_ptr[1];
T max_value = static_cast<T>(0);
T min_value = static_cast<T>(0);
T mean_value = static_cast<T>(0);
if (tensor_block_max_ptr && tensor_block_min_ptr && tensor_block_mean_ptr) {
max_value = tensor_block_max_ptr[0];
min_value = tensor_block_min_ptr[0];
mean_value = tensor_block_mean_ptr[0];
// numel_max_min <= 128
for (int64_t i = 1; i < numel_max_min; ++i) {
T tmp_max_value = tensor_block_max_ptr[i];
T tmp_min_value = tensor_block_min_ptr[i];
T tmp_mean_value = tensor_block_mean_ptr[i];
max_value = tmp_max_value > max_value ? tmp_max_value : max_value;
min_value = tmp_min_value < min_value ? tmp_min_value : min_value;
mean_value += tmp_mean_value;
}
}
if (has_nan || has_inf) {
if (abort_on_nan_inf) {
PADDLE_ENFORCE(false,
"===[PRECISION] [ERROR] in %s, numel=%ld, find_nan=%d, "
"find_inf=%d, "
"max=%e, min=%e, mean=%e===\n",
debug_info,
numel,
has_nan,
has_inf,
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
} else {
printf(
"===[PRECISION] [ERROR] in %s, numel=%ld, find_nan=%d, "
"find_inf=%d, "
"max=%e, min=%e, mean=%e===\n",
debug_info,
numel,
has_nan,
has_inf,
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
}
} else if (check_tensor_max_min) {
printf("[PRECISION] in %s, numel=%ld, max=%e, min=%e, mean=%e\n",
debug_info,
numel,
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
}
}
}
template <>
template <typename T>
void TensorCheckerVisitor<phi::GPUContext>::apply(
......@@ -141,8 +312,6 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
int print_num = 3;
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(tensor_.place()));
int dev_id = tensor_.place().device;
......@@ -152,7 +321,12 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d",
multi_op_var2gpu_str_mutex().size()));
std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ + "]";
std::string dtype_str = DataTypeToString(DataTypeTrait<T>::DataType());
if (dtype_str == "::paddle::platform::float16") {
dtype_str = "float16";
}
std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ +
"] [dtype=" + dtype_str + "]";
char* gpu_str_ptr = NULL;
{
......@@ -212,6 +386,8 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
std::min(static_cast<size_t>(128),
static_cast<size_t>((tensor_.numel() + threads - 1) / threads));
#ifdef __HIPCC__
int print_num = 3;
hipLaunchKernelGGL(CheckNanInfKernel,
dim3(blocks),
dim3(threads),
......@@ -222,8 +398,43 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
print_num,
gpu_str_ptr);
#else
CheckNanInfKernel<<<blocks, threads, 0, dev_ctx->stream()>>>(
tensor_.data<T>(), tensor_.numel(), print_num, gpu_str_ptr);
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
phi::DenseTensor found_nan_inf;
found_nan_inf.Resize({2});
int* found_nan_inf_ptr = found_nan_inf.mutable_data<int>(tensor_.place());
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
found_nan_inf_ptr, 0, 2 * sizeof(int), dev_ctx->stream()));
int64_t numel_max_min = blocks;
phi::DenseTensor tensor_block_max_min;
tensor_block_max_min.Resize({static_cast<int64_t>(3 * numel_max_min)});
MT* tensor_block_max_ptr =
tensor_block_max_min.mutable_data<MT>(tensor_.place());
MT* tensor_block_min_ptr = tensor_block_max_ptr + numel_max_min;
MT* tensor_block_mean_ptr = tensor_block_max_ptr + 2 * numel_max_min;
FindNanInfAndBlockMaxMin<T, MT>
<<<blocks, threads, 0, dev_ctx->stream()>>>(tensor_.data<T>(),
tensor_.numel(),
found_nan_inf_ptr,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr);
bool abort_on_nan_inf = FLAGS_abort_on_nan_inf;
bool check_tensor_max_min = FLAGS_check_tensor_max_min;
FindGlobalMaxMinAndPrint<MT>
<<<1, 1, 0, dev_ctx->stream()>>>(found_nan_inf_ptr,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr,
gpu_str_ptr,
tensor_.numel(),
numel_max_min,
abort_on_nan_inf,
check_tensor_max_min);
#endif
}
......
......@@ -68,6 +68,34 @@ PADDLE_DEFINE_EXPORTED_bool(
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
/**
* Operator related FLAG
* Name: FLAGS_abort_on_nan_inf
* Since Version: 2.5.0
* Value Range: bool, default=true
* Example:
* Note: Used to debug. Whether abort the process when any operator produce
* NAN/INF. It only works when FLAGS_check_nan_inf is set.
*/
PADDLE_DEFINE_EXPORTED_bool(
abort_on_nan_inf,
true,
"Whether abort the process when any operator produce NAN/INF or not.");
/**
* Operator related FLAG
* Name: FLAGS_check_tensor_max_min
* Since Version: 2.5.0
* Value Range: bool, default=false
* Example:
* Note: Used to debug. Enable to calculate and print the max and min value of
* each operator's output tensor. It only works when FLAGS_check_nan_inf is set.
*/
PADDLE_DEFINE_EXPORTED_bool(
check_tensor_max_min,
false,
"Whether to check all the output tensors's min and max value.");
/**
* Operator related FLAG
* Name: FLAGS_check_nan_inf
......
......@@ -47,7 +47,7 @@ class TestNanInf(unittest.TestCase):
# in python3, type(out+err) is 'bytes', need use encode
if paddle.fluid.core.is_compiled_with_cuda():
assert (out + err).find('find nan or inf==='.encode()) != -1
assert (out + err).find('find_nan=1, find_inf=1'.encode()) != -1
else:
assert (out + err).find(
'There are `nan` or `inf` in tensor'.encode()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册