Unverified commit 44bd5927, authored by Yiqun Liu, committed by GitHub

[AMP] Reimplement check_nan_inf as check_numerics_kernel. (#52245)

* Reimplement the check_nan_inf function as a check_numerics kernel.

* Move the CPU implementation to phi.

* Add an ifdef guard around the inclusion of omp.h.

* Move the use of FLAGS_check_nan_inf_level out of the header file.

* Implement a common PrintAndThrowError function.

* Fix the erroneous use of __NVCC__, replacing it with __CUDA_ARCH__.

* Add a dependency on phi.

* Polish code and unit tests.
Parent 1ba1627d
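The tests added in this PR exercise the reworked check through the existing FLAGS_check_nan_inf and FLAGS_check_nan_inf_level flags. The following is a minimal, hedged sketch of that usage (eager mode on CPU; flag names, values, and behavior are taken from the tests and the PrintForDifferentLevel code in this diff, and the exact exception type may vary by device, as the tests note):

```python
import numpy as np
import paddle

# Level 0 aborts on the first checked tensor that contains NaN or Inf;
# levels 2 and 3 only print statistics (see PrintForDifferentLevel below).
paddle.set_flags({"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 0})

x = paddle.to_tensor(np.array([1.0, 0.0, -1.0], dtype="float32"))
try:
    out = paddle.log(x)  # log(0) -> -inf, log(-1) -> nan
except Exception as e:
    # The message carries the hint string (device, op, tensor, dtype)
    # together with num_nan / num_inf / num_zero counts.
    print(type(e), e)
```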
......@@ -70,8 +70,8 @@ endif()
if(WITH_GPU)
nv_library(
nan_inf_utils
SRCS nan_inf_utils_detail.cc nan_inf_utils_detail.cu
DEPS framework_proto scope place)
SRCS nan_inf_utils_detail.cc
DEPS framework_proto scope place phi)
nv_library(
all_reduce_op_handle
SRCS all_reduce_op_handle.cc
......@@ -144,8 +144,8 @@ if(WITH_GPU)
elseif(WITH_ROCM)
hip_library(
nan_inf_utils
SRCS nan_inf_utils_detail.cc nan_inf_utils_detail.cu
DEPS framework_proto scope place)
SRCS nan_inf_utils_detail.cc
DEPS framework_proto scope place phi)
hip_library(
all_reduce_op_handle
SRCS all_reduce_op_handle.cc
......@@ -204,7 +204,7 @@ else()
cc_library(
nan_inf_utils
SRCS nan_inf_utils_detail.cc
DEPS framework_proto scope place)
DEPS framework_proto scope place phi)
cc_library(
all_reduce_op_handle
SRCS all_reduce_op_handle.cc
......
......@@ -157,29 +157,6 @@ static void InitWhiteListFormEnv() {
}
}
template <>
template <typename T>
void TensorCheckerVisitor<phi::CPUContext>::apply(
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
std::string cpu_hint_str =
GetCpuHintString<T>(op_type, var_name, tensor.place());
CheckNanInfCpuImpl(tensor.data<T>(), tensor.numel(), cpu_hint_str);
}
template <>
void tensor_check<phi::CPUContext>(const std::string& op_type,
const std::string& var_name,
const phi::DenseTensor& tensor,
const platform::Place& place) {
TensorCheckerVisitor<phi::CPUContext> vistor(
op_type, var_name, tensor, place);
VisitDataType(framework::TransToProtoVarType(tensor.dtype()), vistor);
}
void CheckVarHasNanOrInf(const std::string& op_type,
const std::string& var_name,
const framework::Variable* var,
......
......@@ -13,25 +13,15 @@
// limitations under the License.
#pragma once
#include <fstream>
#include <iostream>
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/flags.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/check_numerics_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#define MKDIR(path) _mkdir(path)
#else
#include <sys/stat.h>
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif
PHI_DECLARE_int32(check_nan_inf_level);
namespace paddle {
namespace framework {
namespace details {
......@@ -44,284 +34,7 @@ void SetNanInfStackLimit(const int& stack_limit);
int GetNanInfStackLimit();
template <typename T,
typename MT,
std::enable_if_t<std::is_same<T, float>::value, bool> = true>
HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
if (check_nan_inf_level >= 3) {
return true;
} else if (check_nan_inf_level >= 2) {
MT fp16_max =
static_cast<MT>(std::numeric_limits<phi::dtype::float16>::max());
return max_value > fp16_max || min_value < -fp16_max;
}
return false;
}
template <typename T,
typename MT,
std::enable_if_t<!std::is_same<T, float>::value, bool> = true>
HOSTDEVICE bool NeedPrint(MT max_value UNUSED,
MT min_value UNUSED,
int check_nan_inf_level) {
if (check_nan_inf_level >= 3) {
return true;
}
return false;
}
template <typename T, typename MT>
HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
int64_t numel,
int64_t num_nan,
int64_t num_inf,
int64_t num_zero,
MT max_value,
MT min_value,
MT mean_value,
int check_nan_inf_level) {
if (num_nan > 0 || num_inf > 0) {
printf(
"[PRECISION] [ERROR] in %s, numel=%lld, num_nan=%lld, "
"num_inf=%lld, num_zero=%lld, max=%e, min=%e, mean=%e\n",
debug_info,
static_cast<long long>(numel), // NOLINT
static_cast<long long>(num_nan), // NOLINT
static_cast<long long>(num_inf), // NOLINT
static_cast<long long>(num_zero), // NOLINT
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
if (check_nan_inf_level == 0) {
#if !(defined(__NVCC__) || defined(__HIPCC__))
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in "
"%s.",
static_cast<long long>(num_nan), // NOLINT
static_cast<long long>(num_inf), // NOLINT
static_cast<long long>(num_zero), // NOLINT
debug_info));
#endif
}
} else if (NeedPrint<T, MT>(max_value, min_value, check_nan_inf_level)) {
printf(
"[PRECISION] in %s, numel=%lld, num_zero=%lld, max=%e, min=%e, "
"mean=%e\n",
debug_info,
static_cast<long long>(numel), // NOLINT
static_cast<long long>(num_zero), // NOLINT
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
}
}
template <typename T, typename MT>
void PrintForDifferentLevelFile(const char* debug_info,
int64_t numel,
int64_t num_nan,
int64_t num_inf,
int64_t num_zero,
MT max_value,
MT min_value,
MT mean_value,
int check_nan_inf_level,
const std::string& log_name) {
int dev_id = 0;
#ifdef PADDLE_WITH_HIP
hipGetDevice(&dev_id);
#elif PADDLE_WITH_CUDA
cudaGetDevice(&dev_id);
#endif
auto file_path = GetNanPath();
MKDIR(file_path.c_str());
std::string file_name = "worker_" + log_name + "." + std::to_string(dev_id);
std::string path = file_path + file_name;
std::ofstream outfile(path, std::ios::app);
if (!outfile.is_open()) {
return;
}
if (num_nan > 0 || num_inf > 0) {
outfile << "[PRECISION] [ERROR] in " << debug_info
<< ", numel=" << static_cast<long long>(numel) // NOLINT
<< ", num_nan=" << static_cast<long long>(num_nan) // NOLINT
<< ", num_inf=" << static_cast<long long>(num_inf) // NOLINT
<< ", num_zero=" << static_cast<long long>(num_zero) // NOLINT
<< ", max=" << static_cast<float>(max_value)
<< ", min=" << static_cast<float>(min_value)
<< ", mean=" << static_cast<float>(mean_value) << std::endl;
} else if (NeedPrint<T, MT>(max_value, min_value, check_nan_inf_level)) {
outfile << "[PRECISION] in " << debug_info
<< ", numel=" << static_cast<long long>(numel) // NOLINT
<< ", num_zero=" << static_cast<long long>(num_zero) // NOLINT
<< ", max=" << static_cast<float>(max_value)
<< ", min=" << static_cast<float>(min_value)
<< ", mean=" << static_cast<float>(mean_value) << std::endl;
}
outfile.close();
}
template <typename T>
inline std::string GetCpuHintString(const std::string& op_type,
const std::string& var_name,
const phi::Place& place,
int device_id = -1) {
std::string dtype_str = DataTypeToString(DataTypeTrait<T>::DataType());
if (dtype_str == "float") {
dtype_str = "fp32";
} else if (dtype_str == "double") {
dtype_str = "fp64";
} else if (dtype_str == "::paddle::platform::float16") {
dtype_str = "fp16";
} else if (dtype_str == "::paddle::platform::bfloat16") {
dtype_str = "bf16";
}
std::stringstream ss;
if (platform::is_gpu_place(place)) {
ss << "[device=gpu:" << device_id << ", ";
} else {
ss << "[device=cpu, ";
}
ss << "op=" << op_type << ", tensor=" << var_name << ", dtype=" << dtype_str
<< "]";
return ss.str();
}
template <
typename T,
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
!std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
static void CheckNanInfCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const std::string log_name = "cpu") {
using MT = typename phi::dtype::template MPTypeTrait<T>::Type;
#ifdef _OPENMP
// Use maximum 4 threads to collect the nan and inf information.
int num_threads = std::max(omp_get_num_threads(), 1);
num_threads = std::min(num_threads, 4);
#else
int num_threads = 1;
#endif
std::vector<int64_t> thread_num_nan(num_threads, 0);
std::vector<int64_t> thread_num_inf(num_threads, 0);
std::vector<int64_t> thread_num_zero(num_threads, 0);
std::vector<MT> thread_min_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_max_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_mean_value(num_threads, static_cast<MT>(0));
#ifdef _OPENMP
#pragma omp parallel num_threads(num_threads)
#endif
{
#ifdef _OPENMP
int64_t tid = omp_get_thread_num();
int64_t chunk_size = (numel + num_threads - 1) / num_threads;
int64_t begin = tid * chunk_size;
int64_t end = chunk_size + begin > numel ? numel : chunk_size + begin;
#else
int64_t tid = 0;
int64_t begin = 0;
int64_t end = numel;
#endif
for (int64_t i = begin; i < end; ++i) {
MT value = static_cast<MT>(value_ptr[i]);
thread_min_value[tid] = std::min(thread_min_value[tid], value);
thread_max_value[tid] = std::max(thread_max_value[tid], value);
thread_mean_value[tid] += value / static_cast<MT>(numel);
if (std::isnan(value)) {
thread_num_nan[tid] += 1;
} else if (std::isinf(value)) {
thread_num_inf[tid] += 1;
}
if (value == 0) {
thread_num_zero[tid] += 1;
}
}
}
int64_t num_nan = 0;
int64_t num_inf = 0;
int64_t num_zero = 0;
MT min_value = thread_min_value[0];
MT max_value = thread_max_value[0];
MT mean_value = static_cast<MT>(0);
for (int i = 0; i < num_threads; ++i) {
num_nan += thread_num_nan[i];
num_inf += thread_num_inf[i];
num_zero += thread_num_zero[i];
min_value = std::min(thread_min_value[i], min_value);
max_value = std::max(thread_max_value[i], max_value);
mean_value += thread_mean_value[i];
}
auto file_path = GetNanPath();
// Write log to file
if (file_path.size() > 0) {
VLOG(4) << "[FLAGS_check_nan_inf_level=" << FLAGS_check_nan_inf_level
<< "]. Write log to " << file_path;
PrintForDifferentLevelFile<T, MT>(cpu_hint_str.c_str(),
numel,
num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
FLAGS_check_nan_inf_level,
log_name);
return;
}
PrintForDifferentLevel<T, MT>(cpu_hint_str.c_str(),
numel,
num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
FLAGS_check_nan_inf_level);
}
template <
typename T,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
void CheckNanInfCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const std::string log_name = "cpu") {
using RealType = typename T::value_type;
RealType real_sum = 0.0f, imag_sum = 0.0f;
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum)
#endif
for (int64_t i = 0; i < numel; ++i) {
T value = value_ptr[i];
real_sum += (value.real - value.real);
imag_sum += (value.imag - value.imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are NAN or INF in %s.", cpu_hint_str));
}
}
template <typename DeviceContext>
template <typename Context>
struct TensorCheckerVisitor {
TensorCheckerVisitor(const std::string& o,
const std::string& v,
......@@ -341,7 +54,14 @@ struct TensorCheckerVisitor {
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type* =
0) const;
0) const {
auto* dev_ctx = reinterpret_cast<Context*>(
platform::DeviceContextPool::Instance().Get(tensor.place()));
auto file_path = GetNanPath();
phi::CheckNumericsKernel<T, Context>(
*dev_ctx, tensor, op_type, var_name, GetNanInfStackLimit(), file_path);
}
std::string op_type;
std::string var_name;
......@@ -349,11 +69,14 @@ struct TensorCheckerVisitor {
const platform::Place& place;
};
template <typename DeviceContext>
template <typename Context>
void tensor_check(const std::string& op_type,
const std::string& var_name,
const phi::DenseTensor& tensor,
const platform::Place& place);
const platform::Place& place) {
TensorCheckerVisitor<Context> vistor(op_type, var_name, tensor, place);
VisitDataType(framework::TransToProtoVarType(tensor.dtype()), vistor);
}
} // namespace details
} // namespace framework
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CheckNumericsKernel(const Context& ctx,
const DenseTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int stack_height_limit,
const std::string& output_dir);
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/check_numerics_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/check_numerics_utils.h"
PHI_DECLARE_int32(check_nan_inf_level);
namespace phi {
template <typename T, typename Context>
void CheckNumericsKernel(const Context& ctx,
const DenseTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int stack_height_limit,
const std::string& output_dir) {
std::string cpu_hint_str =
phi::funcs::GetCpuHintString<T>(op_type, var_name, tensor.place());
phi::funcs::CheckNumericsCpuImpl(tensor.data<T>(),
tensor.numel(),
cpu_hint_str,
FLAGS_check_nan_inf_level,
"cpu",
output_dir);
}
} // namespace phi
PD_REGISTER_KERNEL(check_numerics,
CPU,
ALL_LAYOUT,
phi::CheckNumericsKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
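The registered kernel can also be invoked directly on a tensor; the updated test_check_numerics below does this through paddle.fluid.core.check_numerics. A hedged Python sketch of that direct call (the call signature and flag values are copied from the test; at level 3 the kernel only prints statistics rather than raising):

```python
import numpy as np
import paddle

# At level 3 every checked tensor prints numel, num_nan, num_inf, num_zero,
# max, min, and mean instead of aborting.
paddle.set_flags({"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3})

x = paddle.to_tensor(np.random.random([8, 8]).astype("float32"))
# The first argument is a debug tag that appears in the printed hint string.
paddle.fluid.core.check_numerics("check_tensor", x)
```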
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
#endif
#include <fstream>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#define MKDIR(path) _mkdir(path)
#else
#include <sys/stat.h>
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif
namespace phi {
namespace funcs {
template <typename T,
typename MT,
std::enable_if_t<std::is_same<T, float>::value, bool> = true>
HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
if (check_nan_inf_level >= 3) {
return true;
} else if (check_nan_inf_level >= 2) {
MT fp16_max =
static_cast<MT>(std::numeric_limits<phi::dtype::float16>::max());
return max_value > fp16_max || min_value < -fp16_max;
}
return false;
}
template <typename T,
typename MT,
std::enable_if_t<!std::is_same<T, float>::value, bool> = true>
HOSTDEVICE bool NeedPrint(MT max_value UNUSED,
MT min_value UNUSED,
int check_nan_inf_level) {
if (check_nan_inf_level >= 3) {
return true;
}
return false;
}
HOSTDEVICE static void PrintAndThrowError(const char* debug_info,
int64_t num_nan,
int64_t num_inf,
int64_t num_zero) {
#if !defined(__HIPCC__) && !defined(__CUDA_ARCH__)
PADDLE_THROW(phi::errors::PreconditionNotMet(
"There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in "
"%s.",
static_cast<long long>(num_nan), // NOLINT
static_cast<long long>(num_inf), // NOLINT
static_cast<long long>(num_zero), // NOLINT
debug_info));
#endif
}
template <typename T, typename MT>
HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
int64_t numel,
int64_t num_nan,
int64_t num_inf,
int64_t num_zero,
MT max_value,
MT min_value,
MT mean_value,
int check_nan_inf_level) {
if (num_nan > 0 || num_inf > 0) {
printf(
"[PRECISION] [ERROR] in %s, numel=%lld, num_nan=%lld, "
"num_inf=%lld, num_zero=%lld, max=%e, min=%e, mean=%e\n",
debug_info,
static_cast<long long>(numel), // NOLINT
static_cast<long long>(num_nan), // NOLINT
static_cast<long long>(num_inf), // NOLINT
static_cast<long long>(num_zero), // NOLINT
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
if (check_nan_inf_level == 0) {
PrintAndThrowError(debug_info, num_nan, num_inf, num_zero);
}
} else if (NeedPrint<T, MT>(max_value, min_value, check_nan_inf_level)) {
printf(
"[PRECISION] in %s, numel=%lld, num_zero=%lld, max=%e, min=%e, "
"mean=%e\n",
debug_info,
static_cast<long long>(numel), // NOLINT
static_cast<long long>(num_zero), // NOLINT
static_cast<float>(max_value),
static_cast<float>(min_value),
static_cast<float>(mean_value));
}
}
template <typename T, typename MT>
void WriteToFileForDifferentLevel(const char* debug_info,
int64_t numel,
int64_t num_nan,
int64_t num_inf,
int64_t num_zero,
MT max_value,
MT min_value,
MT mean_value,
int check_nan_inf_level,
const std::string& log_name,
const std::string output_dir) {
MKDIR(output_dir.c_str());
std::string filename = output_dir + "worker_" + log_name;
std::ofstream outfile(filename, std::ios::app);
PADDLE_ENFORCE_EQ(
outfile.is_open(),
true,
phi::errors::Unavailable("Fail to open output file %s, please check the "
"specified output_dir (%s).",
filename,
output_dir));
if (num_nan > 0 || num_inf > 0) {
outfile << "[PRECISION] [ERROR] in " << debug_info
<< ", numel=" << static_cast<long long>(numel) // NOLINT
<< ", num_nan=" << static_cast<long long>(num_nan) // NOLINT
<< ", num_inf=" << static_cast<long long>(num_inf) // NOLINT
<< ", num_zero=" << static_cast<long long>(num_zero) // NOLINT
<< std::scientific << std::setprecision(6)
<< ", max=" << static_cast<float>(max_value)
<< ", min=" << static_cast<float>(min_value)
<< ", mean=" << static_cast<float>(mean_value) << std::endl;
} else if (phi::funcs::NeedPrint<T, MT>(
max_value, min_value, check_nan_inf_level)) {
outfile << "[PRECISION] in " << debug_info
<< ", numel=" << static_cast<long long>(numel) // NOLINT
<< ", num_zero=" << static_cast<long long>(num_zero) // NOLINT
<< std::scientific << std::setprecision(6)
<< ", max=" << static_cast<float>(max_value)
<< ", min=" << static_cast<float>(min_value)
<< ", mean=" << static_cast<float>(mean_value) << std::endl;
}
outfile.close();
}
template <typename T>
inline std::string GetCpuHintString(const std::string& op_type,
const std::string& var_name,
const phi::Place& place,
int device_id = -1) {
std::string dtype_str;
phi::DataType dtype = phi::CppTypeToDataType<T>::Type();
if (dtype == DataType::FLOAT32) {
dtype_str = "fp32";
} else if (dtype == DataType::FLOAT64) {
dtype_str = "fp64";
} else if (dtype == DataType::FLOAT16) {
dtype_str = "fp16";
} else if (dtype == DataType::BFLOAT16) {
dtype_str = "bf16";
}
std::stringstream ss;
if (place.GetType() == phi::AllocationType::GPU) {
ss << "[device=gpu:" << device_id << ", ";
} else {
ss << "[device=cpu, ";
}
ss << "op=" << op_type << ", tensor=" << var_name << ", dtype=" << dtype_str
<< "]";
return ss.str();
}
template <
typename T,
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
!std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
static void CheckNumericsCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const int check_nan_inf_level,
const std::string log_name = "cpu",
const std::string output_dir = "") {
using MT = typename phi::dtype::template MPTypeTrait<T>::Type;
#ifdef _OPENMP
// Use maximum 4 threads to collect the nan and inf information.
int num_threads = std::max(omp_get_num_threads(), 1);
num_threads = std::min(num_threads, 4);
#else
int num_threads = 1;
#endif
std::vector<int64_t> thread_num_nan(num_threads, 0);
std::vector<int64_t> thread_num_inf(num_threads, 0);
std::vector<int64_t> thread_num_zero(num_threads, 0);
std::vector<MT> thread_min_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_max_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_mean_value(num_threads, static_cast<MT>(0));
#ifdef _OPENMP
#pragma omp parallel num_threads(num_threads)
#endif
{
#ifdef _OPENMP
int64_t tid = omp_get_thread_num();
int64_t chunk_size = (numel + num_threads - 1) / num_threads;
int64_t begin = tid * chunk_size;
int64_t end = chunk_size + begin > numel ? numel : chunk_size + begin;
#else
int64_t tid = 0;
int64_t begin = 0;
int64_t end = numel;
#endif
for (int64_t i = begin; i < end; ++i) {
MT value = static_cast<MT>(value_ptr[i]);
thread_min_value[tid] = std::min(thread_min_value[tid], value);
thread_max_value[tid] = std::max(thread_max_value[tid], value);
thread_mean_value[tid] += value / static_cast<MT>(numel);
if (std::isnan(value)) {
thread_num_nan[tid] += 1;
} else if (std::isinf(value)) {
thread_num_inf[tid] += 1;
}
if (value == 0) {
thread_num_zero[tid] += 1;
}
}
}
int64_t num_nan = 0;
int64_t num_inf = 0;
int64_t num_zero = 0;
MT min_value = thread_min_value[0];
MT max_value = thread_max_value[0];
MT mean_value = static_cast<MT>(0);
for (int i = 0; i < num_threads; ++i) {
num_nan += thread_num_nan[i];
num_inf += thread_num_inf[i];
num_zero += thread_num_zero[i];
min_value = std::min(thread_min_value[i], min_value);
max_value = std::max(thread_max_value[i], max_value);
mean_value += thread_mean_value[i];
}
// Write log to file
if (output_dir.size() > 0) {
WriteToFileForDifferentLevel<T, MT>(cpu_hint_str.c_str(),
numel,
num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
check_nan_inf_level,
log_name,
output_dir);
} else {
PrintForDifferentLevel<T, MT>(cpu_hint_str.c_str(),
numel,
num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
check_nan_inf_level);
}
}
template <
typename T,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
void CheckNumericsCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const int check_nan_inf_level,
const std::string log_name = "cpu",
const std::string output_dir = "") {
using RealType = typename T::value_type;
RealType real_sum = 0.0f, imag_sum = 0.0f;
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum)
#endif
for (int64_t i = 0; i < numel; ++i) {
T value = value_ptr[i];
real_sum += (value.real - value.real);
imag_sum += (value.imag - value.imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW(phi::errors::PreconditionNotMet("There are NAN or INF in %s.",
cpu_hint_str));
}
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/scope.h"
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/check_numerics_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/check_numerics_utils.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
PHI_DECLARE_int32(check_nan_inf_level);
namespace paddle {
namespace framework {
namespace details {
namespace phi {
static std::once_flag init_multi_gpu_op_var_map_flag;
// lazy init
static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>&
static std::vector<
std::unordered_map<std::string, phi::Allocator::AllocationPtr>>&
multi_op_var2gpu_str() {
static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>
static std::vector<
std::unordered_map<std::string, phi::Allocator::AllocationPtr>>
_multi_op_var2gpu_str;
return _multi_op_var2gpu_str;
}
......@@ -49,15 +46,15 @@ static std::vector<std::mutex>& multi_op_var2gpu_str_mutex() {
}
static void InitMultiGPUOpVarMap() {
int dev_count = platform::GetGPUDeviceCount();
int dev_count = phi::backends::gpu::GetGPUDeviceCount();
PADDLE_ENFORCE_GT(dev_count,
0,
platform::errors::NotFound(
phi::errors::NotFound(
"cuda device must > 0, now dev_count=%d", dev_count));
// https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex
std::vector<std::unordered_map<std::string, memory::AllocationPtr>> tmp_multi(
dev_count);
std::vector<std::unordered_map<std::string, phi::Allocator::AllocationPtr>>
tmp_multi(dev_count);
std::vector<std::mutex> tmp_multi_mutex(dev_count);
multi_op_var2gpu_str().swap(tmp_multi);
......@@ -297,7 +294,7 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
int64_t numel,
int64_t numel_max_min,
int check_nan_inf_level,
int64_t* nan_inf_zero) {
int64_t* nan_inf_zero_ptr) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
int64_t num_nan = 0;
int64_t num_inf = 0;
......@@ -329,20 +326,21 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
mean_value += tmp_mean_value;
}
if (check_nan_inf_level == 0) {
nan_inf_zero[0] = num_nan;
nan_inf_zero[1] = num_inf;
nan_inf_zero[2] = num_zero;
nan_inf_zero_ptr[0] = num_nan;
nan_inf_zero_ptr[1] = num_inf;
nan_inf_zero_ptr[2] = num_zero;
}
}
PrintForDifferentLevel<T, MT>(debug_info,
numel,
num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
check_nan_inf_level);
phi::funcs::PrintForDifferentLevel<T, MT>(debug_info,
numel,
num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
check_nan_inf_level);
}
}
......@@ -351,12 +349,13 @@ inline std::string GetHintString(const std::string& op_type,
const std::string& var_name,
const phi::Place& place,
int dev_id = -1) {
std::string op_var = GetCpuHintString<T>(op_type, var_name, place, dev_id);
std::string op_var =
phi::funcs::GetCpuHintString<T>(op_type, var_name, place, dev_id);
PADDLE_ENFORCE_EQ(
(dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()),
true,
platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d",
multi_op_var2gpu_str_mutex().size()));
phi::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d",
multi_op_var2gpu_str_mutex().size()));
return op_var;
}
......@@ -375,7 +374,7 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
std::lock_guard<std::mutex> guard(op_var2gpu_str_mutex);
if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) { // insert
auto gpu_str_tensor = paddle::memory::Alloc(
auto gpu_str_tensor = phi::memory_utils::Alloc(
ctx.GetPlace(),
op_var.length() + 1,
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
......@@ -386,7 +385,7 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
auto iter = op_var2gpu_str.find(op_var);
PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(),
true,
platform::errors::PreconditionNotMet(
phi::errors::PreconditionNotMet(
"op_var=%s should successed insert into "
"op_var2gpu_str, but now failed",
op_var));
......@@ -408,7 +407,7 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
auto iter = op_var2gpu_str.find(op_var);
PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(),
true,
platform::errors::PreconditionNotMet(
phi::errors::PreconditionNotMet(
"op_var=%s should be in the op_var2gpu_str, but "
"now can't find it",
op_var));
......@@ -418,37 +417,41 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
return gpu_str_ptr;
}
template <>
template <typename T>
void TensorCheckerVisitor<phi::GPUContext>::apply(
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(tensor.place()));
template <typename T, typename Context>
void CheckNumericsKernel(const Context& ctx,
const DenseTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int stack_height_limit,
const std::string& output_dir) {
std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);
int dev_id = tensor.place().device;
// Write log to file
auto file_path = GetNanPath();
if (file_path.size() > 0) {
VLOG(6) << "op_type=" << op_type << ", var_name=" << var_name
<< ", dev_id=gpu:" << dev_id
<< ", stack_height_limit=" << stack_height_limit
<< ", output_dir=" << output_dir;
// Write log to output_dir.
if (output_dir.size() > 0) {
phi::DenseTensor cpu_tensor;
platform::CPUPlace cpu_place;
cpu_tensor.Resize(tensor.dims());
// 1. copy from gpu to cpu
paddle::framework::TensorCopySync(tensor, cpu_place, &cpu_tensor);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(tensor.place()));
// Copy tensor from GPU to CPU.
phi::Copy(ctx, tensor, CPUPlace(), true, &cpu_tensor);
const std::string debug_info =
GetHintString<T>(op_type, var_name, place, dev_id);
// 2. write log to file
CheckNanInfCpuImpl(cpu_tensor.data<T>(), tensor.numel(), debug_info, "gpu");
GetHintString<T>(op_type, var_name, tensor.place(), dev_id);
std::string log_name = "gpu." + std::to_string(dev_id);
phi::funcs::CheckNumericsCpuImpl(cpu_tensor.data<T>(),
tensor.numel(),
debug_info,
FLAGS_check_nan_inf_level,
log_name,
output_dir);
return;
}
// Write log to window
char* gpu_str_ptr =
GetGpuHintStringPtr<T>(*dev_ctx, op_type, var_name, dev_id);
// Print to the standard output.
char* gpu_str_ptr = GetGpuHintStringPtr<T>(ctx, op_type, var_name, dev_id);
#ifdef __HIPCC__
// HIP will throw GPU memory access fault if threads > 256
......@@ -466,7 +469,7 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
dim3(blocks),
dim3(threads),
0,
dev_ctx->stream(),
ctx.stream(),
tensor.data<T>(),
tensor.numel(),
print_num,
......@@ -479,83 +482,75 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
phi::DenseTensor block_num_nan_inf_zero;
block_num_nan_inf_zero.Resize({static_cast<int64_t>(3 * numel_max_min)});
int64_t* block_num_nan_ptr =
dev_ctx->template Alloc<int64_t>(&block_num_nan_inf_zero);
ctx.template Alloc<int64_t>(&block_num_nan_inf_zero);
int64_t* block_num_inf_ptr = block_num_nan_ptr + numel_max_min;
int64_t* block_num_zero_ptr = block_num_inf_ptr + numel_max_min;
phi::DenseTensor tensor_block_max_min;
tensor_block_max_min.Resize({static_cast<int64_t>(3 * numel_max_min)});
MT* tensor_block_max_ptr = dev_ctx->template Alloc<MT>(&tensor_block_max_min);
MT* tensor_block_max_ptr = ctx.template Alloc<MT>(&tensor_block_max_min);
MT* tensor_block_min_ptr = tensor_block_max_ptr + numel_max_min;
MT* tensor_block_mean_ptr = tensor_block_max_ptr + 2 * numel_max_min;
FindNanInfAndBlockMaxMin<T, MT>
<<<blocks, threads, 0, dev_ctx->stream()>>>(tensor.data<T>(),
tensor.numel(),
block_num_nan_ptr,
block_num_inf_ptr,
block_num_zero_ptr,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr);
<<<blocks, threads, 0, ctx.stream()>>>(tensor.data<T>(),
tensor.numel(),
block_num_nan_ptr,
block_num_inf_ptr,
block_num_zero_ptr,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr);
int check_nan_inf_level = FLAGS_check_nan_inf_level;
phi::DenseTensor nan_inf_zero_tensor;
nan_inf_zero_tensor.Resize({static_cast<int64_t>(3)});
int64_t* nan_inf_zero =
dev_ctx->template Alloc<int64_t>(&nan_inf_zero_tensor);
int64_t* nan_inf_zero_ptr = ctx.template Alloc<int64_t>(&nan_inf_zero_tensor);
FindGlobalMaxMinAndPrint<T, MT>
<<<1, 1, 0, dev_ctx->stream()>>>(block_num_nan_ptr,
block_num_inf_ptr,
block_num_zero_ptr,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr,
gpu_str_ptr,
tensor.numel(),
numel_max_min,
check_nan_inf_level,
nan_inf_zero_tensor.data<int64_t>());
if (check_nan_inf_level == 0 && GetNanInfStackLimit() > 0) {
<<<1, 1, 0, ctx.stream()>>>(block_num_nan_ptr,
block_num_inf_ptr,
block_num_zero_ptr,
tensor_block_max_ptr,
tensor_block_min_ptr,
tensor_block_mean_ptr,
gpu_str_ptr,
tensor.numel(),
numel_max_min,
check_nan_inf_level,
nan_inf_zero_ptr);
if (check_nan_inf_level == 0 && stack_height_limit > 0) {
auto nan_cpu =
phi::memory_utils::Alloc(phi::CPUPlace(), sizeof(int64_t) * 3);
int64_t* nan_cpu_ptr = reinterpret_cast<int64_t*>(nan_cpu->ptr());
phi::memory_utils::Copy(phi::CPUPlace(),
nan_cpu_ptr,
place,
nan_inf_zero,
tensor.place(),
nan_inf_zero_ptr,
3 * sizeof(int64_t),
dev_ctx->stream());
dev_ctx->Wait();
ctx.stream());
ctx.Wait();
if (nan_cpu_ptr[0] > 0 || nan_cpu_ptr[1] > 0) {
const std::string debug_info =
GetHintString<T>(op_type, var_name, place, dev_id);
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in "
"%s.",
static_cast<long long>(nan_cpu_ptr[0]), // NOLINT
static_cast<long long>(nan_cpu_ptr[1]), // NOLINT
static_cast<long long>(nan_cpu_ptr[2]), // NOLINT
debug_info));
GetHintString<T>(op_type, var_name, tensor.place(), dev_id);
phi::funcs::PrintAndThrowError(
debug_info.c_str(), nan_cpu_ptr[0], nan_cpu_ptr[1], nan_cpu_ptr[2]);
}
}
#endif
}
template <>
void tensor_check<phi::GPUContext>(const std::string& op_type,
const std::string& var_name,
const phi::DenseTensor& tensor,
const platform::Place& place) {
std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);
TensorCheckerVisitor<phi::GPUContext> vistor(
op_type, var_name, tensor, place);
VisitDataType(framework::TransToProtoVarType(tensor.dtype()), vistor);
}
} // namespace details
} // namespace framework
} // namespace paddle
} // namespace phi
PD_REGISTER_KERNEL(check_numerics,
GPU,
ALL_LAYOUT,
phi::CheckNumericsKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -33,6 +33,7 @@ def is_allclose(actual, expected, atol=1e-2, rtol=1e-2):
class TensorInfo:
def __init__(self):
self.device = None
self.op_type = None
self.tensor_name = None
self.dtype = None
......@@ -45,7 +46,8 @@ class TensorInfo:
self.num_zero = None
def __str__(self):
return "[TensorInfo] op_type={}, tensor_name={}, dtype={}, numel={}, has_inf={}, has_nan={}, num_zero={}, max_value={:.6f}, min_value={:.6f}, mean_value={:.6f}".format(
return "[TensorInfo] device={}, op_type={}, tensor_name={}, dtype={}, numel={}, num_inf={}, num_nan={}, num_zero={}, max_value={:.6f}, min_value={:.6f}, mean_value={:.6f}".format(
self.device,
self.op_type,
self.tensor_name,
self.dtype,
......@@ -73,6 +75,8 @@ class TensorInfo:
words = word_str.split("=")
if words[0] == "op":
self.op_type = words[1]
elif words[0] == "device":
self.device = words[1]
elif words[0] == "tensor":
self.tensor_name = words[1]
elif words[0] == "dtype":
......@@ -380,8 +384,8 @@ class ExcelWriter:
"max_value": 16,
"min_value": 16,
"mean_value": 16,
"has_inf": 8,
"has_nan": 8,
"num_inf": 8,
"num_nan": 8,
}
title_names = ["op_type", "tensor_name", "numel", "infinite"]
if self.log_fp16_dir is None:
......@@ -412,8 +416,8 @@ class ExcelWriter:
"min_value",
"mean_value",
"num_zero",
"has_inf",
"has_nan",
"num_inf",
"num_nan",
]
)
else:
......@@ -435,8 +439,8 @@ class ExcelWriter:
"min_value",
"mean_value",
"num_zero",
"has_inf",
"has_nan",
"num_inf",
"num_nan",
"max_value",
"min_value",
"mean_value",
......@@ -572,39 +576,45 @@ class ExcelWriter:
print(f"-- OP Types produce infinite outputs: {infinite_op_types}")
def parse_lines(lines, specified_op_list=None):
tensor_info_list = []
for i in range(len(lines)):
if i % 10 == 0:
print(
f"-- Processing {i:-8d} / {len(lines):-8d} line",
end="\r",
)
line = lines[i]
if "[PRECISION]" in line:
tensor_info = TensorInfo()
tensor_info.init_from_string(line)
if (
tensor_info.tensor_name is not None
and tensor_info.tensor_name != ""
):
has_tensor_name = True
if (
specified_op_list is None
or tensor_info.op_type in specified_op_list
):
tensor_info_list.append(tensor_info)
# print(tensor_info)
return tensor_info_list
def parse_log(log_dir, filename, specified_op_list=None):
if log_dir is None or filename is None:
return None
complete_filename = log_dir + "/" + filename
tensor_info_list = []
tensor_info_list = None
has_tensor_name = False
try:
with open(complete_filename, 'r') as f:
lines = f.readlines()
for i in range(len(lines)):
if i % 10 == 0:
print(
f"-- Processing {i:-8d} / {len(lines):-8d} line",
end="\r",
)
# [op=adamw] [tensor=encoder_layer_20_multi_head_att_output_fc_0.w_0], numel: 294912, max: 0.005773, min: -0.005774
line = lines[i]
if "[PRECISION]" in line:
tensor_info = TensorInfo()
tensor_info.init_from_string(line)
if (
tensor_info.tensor_name is not None
and tensor_info.tensor_name != ""
):
has_tensor_name = True
if (
specified_op_list is None
or tensor_info.op_type in specified_op_list
):
tensor_info_list.append(tensor_info)
# print(tensor_info)
tensor_info_list = parse_lines(lines, specified_op_list)
except FileNotFoundError:
print("the file ", complete_filename, "is not found")
return None, has_tensor_name
......
......@@ -17,7 +17,6 @@ import os
import numpy as np
os.environ["FLAGS_check_nan_inf"] = "1"
os.environ["GLOG_vmodule"] = "nan_inf_utils_detail=10"
import paddle
from paddle import fluid
......@@ -59,8 +58,7 @@ def net():
hidden = x
for i in range(2):
hidden = paddle.static.nn.fc(x=hidden, size=400, activation="sigmoid")
hidden = paddle.static.nn.fc(x=hidden, size=400, activation="sigmoid")
hidden = paddle.static.nn.fc(x=hidden, size=3)
cost, y_predict = paddle.nn.functional.softmax_with_cross_entropy(
......
......@@ -12,105 +12,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import numpy as np
os.environ["FLAGS_check_nan_inf"] = "1"
os.environ["GLOG_vmodule"] = "nan_inf_utils_detail=10"
import paddle
from paddle import nn
np.random.seed(0)
# os.environ["GLOG_vmodule"] = "nan_inf_utils_detail=10"
def generator():
batch_size = 5
for i in range(5):
curr_train_x = np.random.randint(
batch_size, size=(batch_size, 3)
).astype("float32")
if i >= 2:
curr_train_x[0, :] = np.nan
curr_train_x[-1, :] = np.inf
res = []
for i in range(batch_size):
y = i % 3
res.append([y])
y_label = np.array(res).astype('int64')
yield [curr_train_x, y_label]
paddle.seed(0)
np.random.seed(0)
class TestLayer(nn.Layer):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(3, 400)
self.linear2 = nn.Linear(400, 400)
self.linear3 = nn.Linear(400, 3)
w_1_np = np.random.random([32, 400]).astype("float32")
self.linear1 = nn.Linear(
in_features=32,
out_features=400,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Assign(w_1_np)
),
)
w_2_np = np.random.random([400, 10]).astype("float32")
self.linear2 = nn.Linear(
in_features=400,
out_features=10,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Assign(w_2_np)
),
)
def forward(self, x):
x = self.linear1(x)
x = nn.functional.sigmoid(x)
x = self.linear2(x)
x = nn.functional.sigmoid(x)
x = self.linear3(x)
x = nn.functional.softmax(x)
return x
out = self.linear1(x)
out = nn.functional.sigmoid(out)
out = self.linear2(out)
mask = paddle.randint(low=0, high=2, shape=out.shape).astype("float32")
out = paddle.divide(out, mask)
out = nn.functional.softmax(out)
return out
def check(use_cuda):
def check_main(use_cuda, use_amp=False):
paddle.set_device('gpu' if use_cuda else 'cpu')
net = TestLayer()
sgd = paddle.optimizer.SGD(learning_rate=0.05, parameters=net.parameters())
for step, (x, y) in enumerate(generator()):
x = paddle.to_tensor(x)
y = paddle.to_tensor(y)
zero = paddle.zeros(shape=[1], dtype='int64')
fp16_zero = paddle.cast(zero, dtype='float16')
y = y + zero
y_pred = net(x)
cost = nn.functional.cross_entropy(y_pred, y, use_softmax=False)
avg_cost = paddle.mean(cost)
acc_top1 = paddle.metric.accuracy(input=y_pred, label=y, k=1)
print(
'iter={:.0f}, cost={}, acc1={}'.format(
step, avg_cost.numpy(), acc_top1.numpy()
)
)
sgd.step()
sgd.clear_grad()
def run_check():
if paddle.is_compiled_with_cuda():
try:
check(use_cuda=True)
raise AssertionError()
except Exception as e:
print(e)
print(type(e))
# Note. Enforce in cuda kernel may not catch in paddle, and
# Exception type will be RuntimeError
assert type(e) == OSError or type(e) == RuntimeError
try:
check(use_cuda=False)
raise AssertionError()
except Exception as e:
print(e)
print(type(e))
assert type(e) == RuntimeError
model = TestLayer()
sgd = paddle.optimizer.SGD(
learning_rate=0.05, parameters=model.parameters()
)
if use_cuda and use_amp:
scaler = paddle.amp.GradScaler()
x_np = 10000 * np.random.random([128, 32]).astype("float32")
x = paddle.to_tensor(x_np)
if use_cuda and use_amp:
with paddle.amp.auto_cast(enable=True, dtype="float16", level="O1"):
out = model(x)
loss = paddle.mean(out)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(sgd, scaled)
else:
out = model(x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_grad()
def run_check(args):
paddle.set_flags(
{
"FLAGS_check_nan_inf": 1,
"FLAGS_check_nan_inf_level": args.check_nan_inf_level,
}
)
use_cuda = args.use_cuda and paddle.is_compiled_with_cuda()
if args.check_nan_inf_level == 0:
if use_cuda:
try:
check_main(use_cuda=True, use_amp=args.use_amp)
raise AssertionError()
except Exception as e:
print(e)
print(type(e))
# Note. Enforce in cuda kernel may not catch in paddle, and
# Exception type will be RuntimeError
assert type(e) == OSError or type(e) == RuntimeError
else:
try:
check_main(use_cuda=False, use_amp=False)
raise AssertionError()
except Exception as e:
print(e)
print(type(e))
assert type(e) == RuntimeError
else:
check_main(use_cuda=use_cuda, use_amp=args.use_amp)
if __name__ == '__main__':
run_check()
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--use_cuda', action='store_true', default=False)
parser.add_argument('--use_amp', action='store_true', default=False)
parser.add_argument('--check_nan_inf_level', type=int, default=0)
args = parser.parse_args()
run_check(args)
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import subprocess
import sys
......@@ -22,7 +23,7 @@ import numpy as np
import paddle
class TestNanInf(unittest.TestCase):
class TestNanInfBase(unittest.TestCase):
def setUp(self):
self._python_interp = sys.executable
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
......@@ -30,9 +31,8 @@ class TestNanInf(unittest.TestCase):
self.env = os.environ.copy()
def check_nan_inf(self):
cmd = self._python_interp
def run_command(self, cmd):
print(f"Run command: {cmd}")
proc = subprocess.Popen(
cmd.split(" "),
stdout=subprocess.PIPE,
......@@ -42,20 +42,100 @@ class TestNanInf(unittest.TestCase):
out, err = proc.communicate()
returncode = proc.returncode
# in python3, type(out+err) is 'bytes', need use encode
assert (out + err).find(b'There are NAN or INF') != -1
return returncode, out, err
def test_nan_inf_in_static_mode(self):
self._python_interp += (
" " + os.path.dirname(__file__) + "/check_nan_inf_base.py"
)
self.check_nan_inf()
def test_nan_inf_in_dynamic_mode(self):
self._python_interp += (
" " + os.path.dirname(__file__) + "/check_nan_inf_base_dygraph.py"
)
self.check_nan_inf()
class TestNanInf(TestNanInfBase):
def setUp(self):
super().setUp()
self.check_static = True
self.check_dygraph = True
self.check_nan_inf_level = 0
self.dygraph_expected_op_count = {"divide": 1}
def check_op_count(self, log, expected_op_count=None):
if expected_op_count is None:
return
lines = copy.copy(log).decode().split("\n")
actual_op_count = {}
tensor_info_list = paddle.amp.accuracy_compare.parse_lines(lines)
for tensor_info in tensor_info_list:
print(tensor_info)
if actual_op_count.get(tensor_info.op_type, None) is None:
actual_op_count[tensor_info.op_type] = 1
else:
actual_op_count[tensor_info.op_type] += 1
print(actual_op_count)
for op_type, expected_value in expected_op_count.items():
actual_value = actual_op_count.get(op_type, 0)
self.assertEqual(
actual_value,
expected_value,
f"The number of operator < {op_type} > is expected to be {expected_value}, but recieved {actual_value}.",
)
print("")
def run_check_nan_inf(self, cmd, expected_op_count=None):
returncode, out, err = self.run_command(cmd)
self.check_op_count(out, expected_op_count)
if self.check_nan_inf_level == 0:
# in python3, type(out+err) is 'bytes', need use encode
self.assertNotEqual(
(out + err).find(b'There are NAN or INF'),
-1,
f"Cannot find NAN / INF keyword in:\n{out + err}",
)
def test_nan_inf_static(self):
if not self.check_static:
return
filepath = os.path.dirname(__file__) + "/check_nan_inf_base.py"
cmd = f"{self._python_interp} {filepath}"
self.run_check_nan_inf(cmd, None)
def test_nan_inf_dynamic(self):
if not self.check_dygraph:
return
filepath = os.path.dirname(__file__) + "/check_nan_inf_base_dygraph.py"
# Test on CPU.
cmd = f"{self._python_interp} {filepath} --check_nan_inf_level {self.check_nan_inf_level}"
self.run_check_nan_inf(cmd, self.dygraph_expected_op_count)
# Test on GPU.
if paddle.fluid.core.is_compiled_with_cuda():
cmd = f"{self._python_interp} {filepath} --use_cuda --check_nan_inf_level {self.check_nan_inf_level}"
self.run_check_nan_inf(cmd, self.dygraph_expected_op_count)
class TestCheckAll(TestNanInf):
def setUp(self):
super().setUp()
self.check_static = False
self.check_dygraph = True
self.check_nan_inf_level = 3
self.dygraph_expected_op_count = {
'assign_value_': 2,
'full_': 3,
'matmul': 2,
'add': 2,
'sigmoid': 1,
'cast': 1,
'divide': 1,
'softmax': 1,
'mean': 1,
'mean_grad': 1,
'softmax_grad': 1,
'divide_grad': 1,
'add_grad': 4,
'matmul_grad': 3,
'sigmoid_grad': 1,
'sgd_': 4,
}
class TestNanInfEnv(TestNanInf):
......@@ -67,24 +147,31 @@ class TestNanInfEnv(TestNanInf):
self.env["PADDLE_INF_NAN_SKIP_ROLE"] = "loss"
self.env["PADDLE_INF_NAN_SKIP_VAR"] = "elementwise_add:fc_0.tmp_1"
self.check_static = True
self.check_dygraph = False
self.check_nan_inf_level = 0
self.dygraph_expected_op_count = None
class TestCheckSkipEnv(TestNanInf):
def setUp(self):
super().setUp()
# windows python have some bug with env, so need use str to pass ci
# otherwise, "TypeError: environment can only contain strings"
self.env["Paddle_check_nan_inf_op_list"] = "mean"
self.env["Paddle_skip_nan_inf_op_list"] = "elementwise_add"
class TestNanInfStack(TestNanInfBase):
def check_stack(self, file_name):
cmd = self._python_interp + file_name
returncode, out, err = self.run_command(cmd)
class TestNanInfCheckResult(unittest.TestCase):
def setUp(self):
self._python_interp = sys.executable
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
self._python_interp += " -m coverage run --branch -p"
print(out)
print(err)
self.env = os.environ.copy()
# in python3, type(out+err) is 'bytes', need use encode
assert (out + err).find(b' z = paddle.pow(x, y)') != -1
def test_check_stack(self):
self.check_stack(" check_nan_inf_backward_stack.py")
def test_statck_check_stack(self):
self.check_stack(" check_nan_inf_backward_static_stack.py")
class TestNanInfCheckResult(TestNanInfBase):
def generate_inputs(self, shape, dtype="float32"):
data = np.random.random(size=shape).astype(dtype)
# [-10, 10)
......@@ -148,32 +235,11 @@ class TestNanInfCheckResult(unittest.TestCase):
if paddle.fluid.core.is_compiled_with_cuda():
_check_num_nan_inf(use_cuda=True)
def check_stack(self, file_name):
self._python_interp += file_name
cmd = self._python_interp
proc = subprocess.Popen(
cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=self.env,
def run_check_nan_inf_level(self, use_cuda, dtype, level):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": level}
)
out, err = proc.communicate()
returncode = proc.returncode
print(out)
print(err)
# in python3, type(out+err) is 'bytes', need use encode
assert (out + err).find(b' z = paddle.pow(x, y)') != -1
def test_check_stack(self):
self.check_stack(" check_nan_inf_backward_stack.py")
def test_statck_check_stack(self):
self.check_stack(" check_nan_inf_backward_static_stack.py")
def check_nan_inf_level(self, use_cuda, dtype):
shape = [8, 8]
x_np, y_np = self.generate_inputs(shape, dtype)
......@@ -186,33 +252,36 @@ class TestNanInfCheckResult(unittest.TestCase):
out = paddle.log(x * 1e6) / y
def test_check_nan_inf_level_float32(self):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 2}
level = 2
self.run_check_nan_inf_level(
use_cuda=False, dtype="float32", level=level
)
self.check_nan_inf_level(use_cuda=False, dtype="float32")
if paddle.fluid.core.is_compiled_with_cuda():
self.check_nan_inf_level(use_cuda=True, dtype="float32")
self.run_check_nan_inf_level(
use_cuda=True, dtype="float32", level=level
)
def test_check_nan_inf_level_float16(self):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3}
level = 3
self.run_check_nan_inf_level(
use_cuda=False, dtype="float32", level=level
)
if paddle.fluid.core.is_compiled_with_cuda():
self.check_nan_inf_level(use_cuda=True, dtype="float16")
self.run_check_nan_inf_level(
use_cuda=True, dtype="float16", level=level
)
def test_check_numerics(self):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3}
)
if paddle.fluid.core.is_compiled_with_cuda():
self.check_nan_inf_level(use_cuda=True, dtype="float16")
shape = [8, 8]
x_np, y_np = self.generate_inputs(shape, "float16")
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
paddle.fluid.core.check_numerics("check_numerics", x)
paddle.fluid.core.check_numerics("check_numerics", y)
paddle.fluid.core.check_numerics("check_tensor", x)
paddle.fluid.core.check_numerics("check_tensor", y)
if __name__ == '__main__':
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import os
import sys
import tempfile
import unittest
import numpy as np
......@@ -22,97 +22,99 @@ import paddle
class TestNanInfDirCheckResult(unittest.TestCase):
def generate_inputs(self, shape, dtype="float32"):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def generate_inputs(self, shape, low=0, high=1, dtype="float32"):
data = np.random.random(size=shape).astype(dtype)
# [-10, 10)
x = (data * 20 - 10) * np.random.randint(
x = (data * (high - low) + low) * np.random.randint(
low=0, high=2, size=shape
).astype(dtype)
y = np.random.randint(low=0, high=2, size=shape).astype(dtype)
return x, y
return x
def get_reference_num_nan_inf(self, x):
out = np.log(x)
num_nan = np.sum(np.isnan(out))
num_inf = np.sum(np.isinf(out))
print(f"[reference] num_nan={num_nan}, num_inf={num_inf}")
print(f"-- [reference] num_nan={num_nan}, num_inf={num_inf}")
return num_nan, num_inf
def get_num_nan_inf(
self, x_np, use_cuda=True, add_assert=False, pt="nan_inf_log_dir"
):
def get_num_nan_inf(self, x_np, use_cuda=True, output_dir=None):
if use_cuda:
paddle.device.set_device("gpu:0")
else:
paddle.device.set_device("cpu")
x = paddle.to_tensor(x_np)
x = x * 0.5
out = paddle.log(x)
if use_cuda:
paddle.device.cuda.synchronize()
self.assertEqual(
os.path.exists(output_dir) and os.path.isdir(output_dir), True
)
num_nan = 0
num_inf = 0
if add_assert:
if use_cuda:
paddle.device.set_device("gpu:0")
else:
paddle.device.set_device("cpu")
x = paddle.to_tensor(x_np)
out = paddle.log(x)
sys.stdout.flush()
if not use_cuda:
os.path.exists(pt)
num_nan = 0
num_inf = 0
for root, dirs, files in os.walk(pt):
for file_name in files:
if file_name.startswith('worker_cpu'):
file_path = os.path.join(root, file_name)
with open(file_path, "rb") as fp:
for e in fp:
err_str_list = (
str(e)
.replace("(", " ")
.replace(")", " ")
.replace(",", " ")
.split(" ")
)
for err_str in err_str_list:
if "num_nan" in err_str:
num_nan = int(err_str.split("=")[1])
elif "num_inf" in err_str:
num_inf = int(err_str.split("=")[1])
print(f"[paddle] num_nan={num_nan}, num_inf={num_inf}")
prefix = "worker_gpu" if use_cuda else "worker_cpu"
for filename in os.listdir(output_dir):
if filename.startswith(prefix):
filepath = os.path.join(output_dir, filename)
print(f"-- Parse {filepath}")
with open(filepath, "rb") as fp:
for e in fp:
err_str_list = (
str(e)
.replace("(", " ")
.replace(")", " ")
.replace(",", " ")
.split(" ")
)
for err_str in err_str_list:
if "num_nan" in err_str:
num_nan = int(err_str.split("=")[1])
elif "num_inf" in err_str:
num_inf = int(err_str.split("=")[1])
print(
f"-- [paddle] use_cuda={use_cuda}, num_nan={num_nan}, num_inf={num_inf}"
)
return num_nan, num_inf
def test_num_nan_inf(self):
path = "nan_inf_log_dir"
def check_num_nan_inf(self, x_np, use_cuda, subdir):
output_dir = self.temp_dir.name + "/" + subdir
print(f"-- output_dir: {output_dir}")
checker_config = paddle.amp.debugging.TensorCheckerConfig(
enable=True,
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
output_dir=path,
output_dir=output_dir,
)
paddle.amp.debugging.enable_tensor_checker(checker_config)
def _check_num_nan_inf(use_cuda):
shape = [32, 32]
x_np, _ = self.generate_inputs(shape)
num_nan_np, num_inf_np = self.get_reference_num_nan_inf(x_np)
add_assert = (num_nan_np + num_inf_np) > 0
num_nan, num_inf = self.get_num_nan_inf(
x_np,
use_cuda,
add_assert,
path,
)
if not use_cuda:
assert num_nan == num_nan_np and num_inf == num_inf_np
if paddle.fluid.core.is_compiled_with_cuda():
_check_num_nan_inf(use_cuda=True)
else:
_check_num_nan_inf(use_cuda=False)
num_nan_np, num_inf_np = self.get_reference_num_nan_inf(x_np)
num_nan, num_inf = self.get_num_nan_inf(
x_np,
use_cuda,
output_dir,
)
self.assertEqual(num_nan, num_nan_np)
self.assertEqual(num_inf, num_inf_np)
x = paddle.to_tensor([2, 3, 4], 'float32')
y = paddle.to_tensor([1, 5, 2], 'float32')
z = paddle.add(x, y)
path = ""
paddle.fluid.core.set_nan_inf_debug_path(path)
paddle.amp.debugging.disable_tensor_checker()
def test_num_nan_inf(self):
shape = [32, 32]
x_np = self.generate_inputs(shape, -10, 10)
self.check_num_nan_inf(
x_np, use_cuda=False, subdir="check_nan_inf_dir_cpu"
)
if paddle.fluid.core.is_compiled_with_cuda():
self.check_num_nan_inf(
x_np, use_cuda=True, subdir="check_nan_inf_dir_gpu"
)
if __name__ == '__main__':
unittest.main()
......@@ -18,7 +18,7 @@ import paddle
class TestTensorChecker(unittest.TestCase):
def get_num_inf(self, e):
def _parse_num_nan_inf(self, e):
num_nan = 0
num_inf = 0
# Cannot catch the log in CUDA kernel.
......@@ -34,14 +34,9 @@ class TestTensorChecker(unittest.TestCase):
num_nan = int(err_str.split("=")[1])
elif "num_inf" in err_str:
num_inf = int(err_str.split("=")[1])
print(
"[CHECK_NAN_INF_AND_ABORT] num_nan={}, num_inf={}".format(
num_nan, num_inf
)
)
return num_nan
return num_nan, num_inf
def generate_num_inf(self, place):
def _generate_num_inf(self, place):
num_inf = 0
num_nan = 0
paddle.set_device(place)
......@@ -58,8 +53,8 @@ class TestTensorChecker(unittest.TestCase):
paddle.autograd.backward([res])
res = paddle.divide(y, x)
except Exception as e:
num_inf = self.get_num_inf(e)
return num_inf
num_nan, num_inf = self._parse_num_nan_inf(e)
return num_nan, num_inf
def test_tensor_checker(self):
def _assert_flag(value):
......@@ -86,22 +81,38 @@ class TestTensorChecker(unittest.TestCase):
for place in places:
paddle.amp.debugging.TensorCheckerConfig.current_step_id = 0
for index in range(5):
for iter_id in range(5):
paddle.amp.debugging.enable_tensor_checker(checker_config)
if index <= 2:
if iter_id <= 2:
_assert_flag(True)
self.assertEqual(
index + 1,
iter_id + 1,
paddle.amp.debugging.TensorCheckerConfig.current_step_id,
)
self.assertEqual(1, self.generate_num_inf(place))
num_nan, num_inf = self._generate_num_inf(place)
print(
f"-- [iter_id={iter_id}, place={place}] num_nan={num_nan}, num_inf={num_inf}"
)
self.assertEqual(
1,
num_nan,
f"Expected num_nan to be 1, but recieved {num_nan}, place={place}.",
)
else:
self.assertEqual(
3,
paddle.amp.debugging.TensorCheckerConfig.current_step_id,
)
_assert_flag(False)
self.assertEqual(0, self.generate_num_inf(place))
num_nan, num_inf = self._generate_num_inf(place)
print(
f"-- [iter_id={iter_id}, place={place}] num_nan={num_nan}, num_inf={num_inf}"
)
self.assertEqual(
0,
num_nan,
f"Expected num_nan to be 1, but recieved {num_nan}, place={place}.",
)
paddle.amp.debugging.disable_tensor_checker()
_assert_flag(False)
......
cc_test(
test_egr_task_nan_inf_utils
SRCS nan_inf_utils_test.cc
DEPS eager_nan_inf_utils)
DEPS eager_nan_inf_utils phi)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_test(
......