From 44bd5927fac7ffae794a6f56fce23c03eee9e7e3 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Tue, 30 May 2023 20:25:43 +0800 Subject: [PATCH] [AMP] Reimplement check_nan_inf as check_numerics_kernel. (#52245) * Reimplement the check_nan_inf function as check_numerics kernel. * Move the CPU implementation to phi. * Add an ifdef guard around the inclusion of omp.h. * Move the use of FLAGS_check_nan_inf_level out of the header file. * Implement a common PrintAndThrowError function. * Fix the erroneous use of __NVCC__, which should be replaced with __CUDA_ARCH__. * Add a dependency on phi. * Polish code and unit tests. --- paddle/fluid/framework/details/CMakeLists.txt | 10 +- .../framework/details/nan_inf_utils_detail.cc | 23 -- .../framework/details/nan_inf_utils_detail.h | 313 +---------------- paddle/phi/kernels/check_numerics_kernel.h | 29 ++ .../phi/kernels/cpu/check_numerics_kernel.cc | 54 +++ .../phi/kernels/funcs/check_numerics_utils.h | 326 ++++++++++++++++++ .../kernels/gpu/check_numerics_kernel.cu} | 255 +++++++------- python/paddle/amp/accuracy_compare.py | 70 ++-- .../tests/unittests/check_nan_inf_base.py | 4 +- .../unittests/check_nan_inf_base_dygraph.py | 169 ++++----- .../fluid/tests/unittests/test_nan_inf.py | 197 +++++++---- .../fluid/tests/unittests/test_nan_inf_dir.py | 140 ++++---- .../amp}/test_tensor_checker.py | 41 ++- test/cpp/eager/task_tests/CMakeLists.txt | 2 +- 14 files changed, 918 insertions(+), 715 deletions(-) create mode 100644 paddle/phi/kernels/check_numerics_kernel.h create mode 100644 paddle/phi/kernels/cpu/check_numerics_kernel.cc create mode 100644 paddle/phi/kernels/funcs/check_numerics_utils.h rename paddle/{fluid/framework/details/nan_inf_utils_detail.cu => phi/kernels/gpu/check_numerics_kernel.cu} (72%) rename {python/paddle/fluid/tests/unittests => test/amp}/test_tensor_checker.py (73%) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index b660cbcef2b..5f304734b24 100644 --- 
a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -70,8 +70,8 @@ endif() if(WITH_GPU) nv_library( nan_inf_utils - SRCS nan_inf_utils_detail.cc nan_inf_utils_detail.cu - DEPS framework_proto scope place) + SRCS nan_inf_utils_detail.cc + DEPS framework_proto scope place phi) nv_library( all_reduce_op_handle SRCS all_reduce_op_handle.cc @@ -144,8 +144,8 @@ if(WITH_GPU) elseif(WITH_ROCM) hip_library( nan_inf_utils - SRCS nan_inf_utils_detail.cc nan_inf_utils_detail.cu - DEPS framework_proto scope place) + SRCS nan_inf_utils_detail.cc + DEPS framework_proto scope place phi) hip_library( all_reduce_op_handle SRCS all_reduce_op_handle.cc @@ -204,7 +204,7 @@ else() cc_library( nan_inf_utils SRCS nan_inf_utils_detail.cc - DEPS framework_proto scope place) + DEPS framework_proto scope place phi) cc_library( all_reduce_op_handle SRCS all_reduce_op_handle.cc diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc index 5d6975df9c4..3316e8f53b0 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc @@ -157,29 +157,6 @@ static void InitWhiteListFormEnv() { } } -template <> -template -void TensorCheckerVisitor::apply( - typename std::enable_if< - std::is_floating_point::value || - std::is_same>::value || - std::is_same>::value>::type*) - const { - std::string cpu_hint_str = - GetCpuHintString(op_type, var_name, tensor.place()); - CheckNanInfCpuImpl(tensor.data(), tensor.numel(), cpu_hint_str); -} - -template <> -void tensor_check(const std::string& op_type, - const std::string& var_name, - const phi::DenseTensor& tensor, - const platform::Place& place) { - TensorCheckerVisitor vistor( - op_type, var_name, tensor, place); - VisitDataType(framework::TransToProtoVarType(tensor.dtype()), vistor); -} - void CheckVarHasNanOrInf(const std::string& op_type, const std::string& var_name, const 
framework::Variable* var, diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.h b/paddle/fluid/framework/details/nan_inf_utils_detail.h index ed2fa25a5ae..e88ccb8f2c7 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.h +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.h @@ -13,25 +13,15 @@ // limitations under the License. #pragma once -#include -#include + #include +#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/core/flags.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/kernels/check_numerics_kernel.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h" -#ifdef _WIN32 -#include -#include -#define MKDIR(path) _mkdir(path) -#else -#include -#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH) -#endif -PHI_DECLARE_int32(check_nan_inf_level); namespace paddle { namespace framework { namespace details { @@ -44,284 +34,7 @@ void SetNanInfStackLimit(const int& stack_limit); int GetNanInfStackLimit(); -template ::value, bool> = true> -HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) { - if (check_nan_inf_level >= 3) { - return true; - } else if (check_nan_inf_level >= 2) { - MT fp16_max = - static_cast(std::numeric_limits::max()); - return max_value > fp16_max || min_value < -fp16_max; - } - return false; -} - -template ::value, bool> = true> -HOSTDEVICE bool NeedPrint(MT max_value UNUSED, - MT min_value UNUSED, - int check_nan_inf_level) { - if (check_nan_inf_level >= 3) { - return true; - } - return false; -} - -template -HOSTDEVICE void PrintForDifferentLevel(const char* debug_info, - int64_t numel, - int64_t num_nan, - int64_t num_inf, - int64_t num_zero, - MT max_value, - MT min_value, - MT mean_value, - int check_nan_inf_level) { - if (num_nan > 0 || 
num_inf > 0) { - printf( - "[PRECISION] [ERROR] in %s, numel=%lld, num_nan=%lld, " - "num_inf=%lld, num_zero=%lld, max=%e, min=%e, mean=%e\n", - debug_info, - static_cast(numel), // NOLINT - static_cast(num_nan), // NOLINT - static_cast(num_inf), // NOLINT - static_cast(num_zero), // NOLINT - static_cast(max_value), - static_cast(min_value), - static_cast(mean_value)); - if (check_nan_inf_level == 0) { -#if !(defined(__NVCC__) || defined(__HIPCC__)) - PADDLE_THROW(platform::errors::PreconditionNotMet( - "There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in " - "%s.", - static_cast(num_nan), // NOLINT - static_cast(num_inf), // NOLINT - static_cast(num_zero), // NOLINT - debug_info)); -#endif - } - } else if (NeedPrint(max_value, min_value, check_nan_inf_level)) { - printf( - "[PRECISION] in %s, numel=%lld, num_zero=%lld, max=%e, min=%e, " - "mean=%e\n", - debug_info, - static_cast(numel), // NOLINT - static_cast(num_zero), // NOLINT - static_cast(max_value), - static_cast(min_value), - static_cast(mean_value)); - } -} - -template -void PrintForDifferentLevelFile(const char* debug_info, - int64_t numel, - int64_t num_nan, - int64_t num_inf, - int64_t num_zero, - MT max_value, - MT min_value, - MT mean_value, - int check_nan_inf_level, - const std::string& log_name) { - int dev_id = 0; -#ifdef PADDLE_WITH_HIP - hipGetDevice(&dev_id); -#elif PADDLE_WITH_CUDA - cudaGetDevice(&dev_id); -#endif - auto file_path = GetNanPath(); - MKDIR(file_path.c_str()); - std::string file_name = "worker_" + log_name + "." 
+ std::to_string(dev_id); - std::string path = file_path + file_name; - std::ofstream outfile(path, std::ios::app); - if (!outfile.is_open()) { - return; - } - - if (num_nan > 0 || num_inf > 0) { - outfile << "[PRECISION] [ERROR] in " << debug_info - << ", numel=" << static_cast(numel) // NOLINT - << ", num_nan=" << static_cast(num_nan) // NOLINT - << ", num_inf=" << static_cast(num_inf) // NOLINT - << ", num_zero=" << static_cast(num_zero) // NOLINT - << ", max=" << static_cast(max_value) - << ", min=" << static_cast(min_value) - << ", mean=" << static_cast(mean_value) << std::endl; - } else if (NeedPrint(max_value, min_value, check_nan_inf_level)) { - outfile << "[PRECISION] in " << debug_info - << ", numel=" << static_cast(numel) // NOLINT - << ", num_zero=" << static_cast(num_zero) // NOLINT - << ", max=" << static_cast(max_value) - << ", min=" << static_cast(min_value) - << ", mean=" << static_cast(mean_value) << std::endl; - } - outfile.close(); -} - -template -inline std::string GetCpuHintString(const std::string& op_type, - const std::string& var_name, - const phi::Place& place, - int device_id = -1) { - std::string dtype_str = DataTypeToString(DataTypeTrait::DataType()); - if (dtype_str == "float") { - dtype_str = "fp32"; - } else if (dtype_str == "double") { - dtype_str = "fp64"; - } else if (dtype_str == "::paddle::platform::float16") { - dtype_str = "fp16"; - } else if (dtype_str == "::paddle::platform::bfloat16") { - dtype_str = "bf16"; - } - - std::stringstream ss; - if (platform::is_gpu_place(place)) { - ss << "[device=gpu:" << device_id << ", "; - } else { - ss << "[device=cpu, "; - } - ss << "op=" << op_type << ", tensor=" << var_name << ", dtype=" << dtype_str - << "]"; - return ss.str(); -} - -template < - typename T, - std::enable_if_t>::value && - !std::is_same>::value, - bool> = true> -static void CheckNanInfCpuImpl(const T* value_ptr, - const int64_t numel, - const std::string& cpu_hint_str, - const std::string log_name = "cpu") { - using MT 
= typename phi::dtype::template MPTypeTrait::Type; - -#ifdef _OPENMP - // Use maximum 4 threads to collect the nan and inf information. - int num_threads = std::max(omp_get_num_threads(), 1); - num_threads = std::min(num_threads, 4); -#else - int num_threads = 1; -#endif - - std::vector thread_num_nan(num_threads, 0); - std::vector thread_num_inf(num_threads, 0); - std::vector thread_num_zero(num_threads, 0); - std::vector thread_min_value(num_threads, static_cast(value_ptr[0])); - std::vector thread_max_value(num_threads, static_cast(value_ptr[0])); - std::vector thread_mean_value(num_threads, static_cast(0)); - -#ifdef _OPENMP -#pragma omp parallel num_threads(num_threads) -#endif - { -#ifdef _OPENMP - int64_t tid = omp_get_thread_num(); - int64_t chunk_size = (numel + num_threads - 1) / num_threads; - int64_t begin = tid * chunk_size; - int64_t end = chunk_size + begin > numel ? numel : chunk_size + begin; -#else - int64_t tid = 0; - int64_t begin = 0; - int64_t end = numel; -#endif - for (int64_t i = begin; i < end; ++i) { - MT value = static_cast(value_ptr[i]); - - thread_min_value[tid] = std::min(thread_min_value[tid], value); - thread_max_value[tid] = std::max(thread_max_value[tid], value); - thread_mean_value[tid] += value / static_cast(numel); - - if (std::isnan(value)) { - thread_num_nan[tid] += 1; - } else if (std::isinf(value)) { - thread_num_inf[tid] += 1; - } - if (value == 0) { - thread_num_zero[tid] += 1; - } - } - } - - int64_t num_nan = 0; - int64_t num_inf = 0; - int64_t num_zero = 0; - MT min_value = thread_min_value[0]; - MT max_value = thread_max_value[0]; - MT mean_value = static_cast(0); - for (int i = 0; i < num_threads; ++i) { - num_nan += thread_num_nan[i]; - num_inf += thread_num_inf[i]; - num_zero += thread_num_zero[i]; - min_value = std::min(thread_min_value[i], min_value); - max_value = std::max(thread_max_value[i], max_value); - mean_value += thread_mean_value[i]; - } - auto file_path = GetNanPath(); - // Write log to file - if 
(file_path.size() > 0) { - VLOG(4) << "[FLAGS_check_nan_inf_level=" << FLAGS_check_nan_inf_level - << "]. Write log to " << file_path; - PrintForDifferentLevelFile(cpu_hint_str.c_str(), - numel, - num_nan, - num_inf, - num_zero, - max_value, - min_value, - mean_value, - FLAGS_check_nan_inf_level, - log_name); - return; - } - - PrintForDifferentLevel(cpu_hint_str.c_str(), - numel, - num_nan, - num_inf, - num_zero, - max_value, - min_value, - mean_value, - FLAGS_check_nan_inf_level); -} - -template < - typename T, - std::enable_if_t>::value || - std::is_same>::value, - bool> = true> -void CheckNanInfCpuImpl(const T* value_ptr, - const int64_t numel, - const std::string& cpu_hint_str, - const std::string log_name = "cpu") { - using RealType = typename T::value_type; - - RealType real_sum = 0.0f, imag_sum = 0.0f; - -#ifdef _OPENMP -#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum) -#endif - for (int64_t i = 0; i < numel; ++i) { - T value = value_ptr[i]; - real_sum += (value.real - value.real); - imag_sum += (value.imag - value.imag); - } - - if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || - std::isinf(imag_sum)) { - // hot fix for compile failed in gcc4.8 - // here also need print detail info of nan or inf later - PADDLE_THROW(platform::errors::PreconditionNotMet( - "There are NAN or INF in %s.", cpu_hint_str)); - } -} - -template +template struct TensorCheckerVisitor { TensorCheckerVisitor(const std::string& o, const std::string& v, @@ -341,7 +54,14 @@ struct TensorCheckerVisitor { std::is_floating_point::value || std::is_same>::value || std::is_same>::value>::type* = - 0) const; + 0) const { + auto* dev_ctx = reinterpret_cast( + platform::DeviceContextPool::Instance().Get(tensor.place())); + + auto file_path = GetNanPath(); + phi::CheckNumericsKernel( + *dev_ctx, tensor, op_type, var_name, GetNanInfStackLimit(), file_path); + } std::string op_type; std::string var_name; @@ -349,11 +69,14 @@ struct 
TensorCheckerVisitor { const platform::Place& place; }; -template +template void tensor_check(const std::string& op_type, const std::string& var_name, const phi::DenseTensor& tensor, - const platform::Place& place); + const platform::Place& place) { + TensorCheckerVisitor vistor(op_type, var_name, tensor, place); + VisitDataType(framework::TransToProtoVarType(tensor.dtype()), vistor); +} } // namespace details } // namespace framework diff --git a/paddle/phi/kernels/check_numerics_kernel.h b/paddle/phi/kernels/check_numerics_kernel.h new file mode 100644 index 00000000000..634ce10bd86 --- /dev/null +++ b/paddle/phi/kernels/check_numerics_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CheckNumericsKernel(const Context& ctx, + const DenseTensor& tensor, + const std::string& op_type, + const std::string& var_name, + const int stack_height_limit, + const std::string& output_dir); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/check_numerics_kernel.cc b/paddle/phi/kernels/cpu/check_numerics_kernel.cc new file mode 100644 index 00000000000..da9dd94f28e --- /dev/null +++ b/paddle/phi/kernels/cpu/check_numerics_kernel.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/check_numerics_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/check_numerics_utils.h" + +PHI_DECLARE_int32(check_nan_inf_level); + +namespace phi { + +template +void CheckNumericsKernel(const Context& ctx, + const DenseTensor& tensor, + const std::string& op_type, + const std::string& var_name, + const int stack_height_limit, + const std::string& output_dir) { + std::string cpu_hint_str = + phi::funcs::GetCpuHintString(op_type, var_name, tensor.place()); + phi::funcs::CheckNumericsCpuImpl(tensor.data(), + tensor.numel(), + cpu_hint_str, + FLAGS_check_nan_inf_level, + "cpu", + output_dir); +} + +} // namespace phi + +PD_REGISTER_KERNEL(check_numerics, + CPU, + ALL_LAYOUT, + phi::CheckNumericsKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/funcs/check_numerics_utils.h b/paddle/phi/kernels/funcs/check_numerics_utils.h new file mode 100644 index 00000000000..9ed247c618d --- /dev/null +++ b/paddle/phi/kernels/funcs/check_numerics_utils.h @@ -0,0 +1,326 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_MKLML +#include +#endif +#include + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/enforce.h" + +#ifdef _WIN32 +#include +#include +#define MKDIR(path) _mkdir(path) +#else +#include +#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH) +#endif + +namespace phi { +namespace funcs { + +template ::value, bool> = true> +HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) { + if (check_nan_inf_level >= 3) { + return true; + } else if (check_nan_inf_level >= 2) { + MT fp16_max = + static_cast(std::numeric_limits::max()); + return max_value > fp16_max || min_value < -fp16_max; + } + return false; +} + +template ::value, bool> = true> +HOSTDEVICE bool NeedPrint(MT max_value UNUSED, + MT min_value UNUSED, + int check_nan_inf_level) { + if (check_nan_inf_level >= 3) { + return true; + } + return false; +} + +HOSTDEVICE static void PrintAndThrowError(const char* debug_info, + int64_t num_nan, + int64_t num_inf, + int64_t num_zero) { +#if !defined(__HIPCC__) && !defined(__CUDA_ARCH__) + PADDLE_THROW(phi::errors::PreconditionNotMet( + "There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in " + "%s.", + static_cast(num_nan), // NOLINT + static_cast(num_inf), // NOLINT + static_cast(num_zero), // NOLINT + debug_info)); +#endif +} + +template +HOSTDEVICE void PrintForDifferentLevel(const char* debug_info, + int64_t numel, + int64_t num_nan, + int64_t num_inf, + int64_t num_zero, + MT max_value, + MT 
min_value, + MT mean_value, + int check_nan_inf_level) { + if (num_nan > 0 || num_inf > 0) { + printf( + "[PRECISION] [ERROR] in %s, numel=%lld, num_nan=%lld, " + "num_inf=%lld, num_zero=%lld, max=%e, min=%e, mean=%e\n", + debug_info, + static_cast(numel), // NOLINT + static_cast(num_nan), // NOLINT + static_cast(num_inf), // NOLINT + static_cast(num_zero), // NOLINT + static_cast(max_value), + static_cast(min_value), + static_cast(mean_value)); + if (check_nan_inf_level == 0) { + PrintAndThrowError(debug_info, num_nan, num_inf, num_zero); + } + } else if (NeedPrint(max_value, min_value, check_nan_inf_level)) { + printf( + "[PRECISION] in %s, numel=%lld, num_zero=%lld, max=%e, min=%e, " + "mean=%e\n", + debug_info, + static_cast(numel), // NOLINT + static_cast(num_zero), // NOLINT + static_cast(max_value), + static_cast(min_value), + static_cast(mean_value)); + } +} + +template +void WriteToFileForDifferentLevel(const char* debug_info, + int64_t numel, + int64_t num_nan, + int64_t num_inf, + int64_t num_zero, + MT max_value, + MT min_value, + MT mean_value, + int check_nan_inf_level, + const std::string& log_name, + const std::string output_dir) { + MKDIR(output_dir.c_str()); + std::string filename = output_dir + "worker_" + log_name; + std::ofstream outfile(filename, std::ios::app); + PADDLE_ENFORCE_EQ( + outfile.is_open(), + true, + phi::errors::Unavailable("Fail to open output file %s, please check the " + "specified output_dir (%s).", + filename, + output_dir)); + + if (num_nan > 0 || num_inf > 0) { + outfile << "[PRECISION] [ERROR] in " << debug_info + << ", numel=" << static_cast(numel) // NOLINT + << ", num_nan=" << static_cast(num_nan) // NOLINT + << ", num_inf=" << static_cast(num_inf) // NOLINT + << ", num_zero=" << static_cast(num_zero) // NOLINT + << std::scientific << std::setprecision(6) + << ", max=" << static_cast(max_value) + << ", min=" << static_cast(min_value) + << ", mean=" << static_cast(mean_value) << std::endl; + } else if 
(phi::funcs::NeedPrint( + max_value, min_value, check_nan_inf_level)) { + outfile << "[PRECISION] in " << debug_info + << ", numel=" << static_cast(numel) // NOLINT + << ", num_zero=" << static_cast(num_zero) // NOLINT + << std::scientific << std::setprecision(6) + << ", max=" << static_cast(max_value) + << ", min=" << static_cast(min_value) + << ", mean=" << static_cast(mean_value) << std::endl; + } + outfile.close(); +} + +template +inline std::string GetCpuHintString(const std::string& op_type, + const std::string& var_name, + const phi::Place& place, + int device_id = -1) { + std::string dtype_str; + phi::DataType dtype = phi::CppTypeToDataType::Type(); + if (dtype == DataType::FLOAT32) { + dtype_str = "fp32"; + } else if (dtype == DataType::FLOAT64) { + dtype_str = "fp64"; + } else if (dtype == DataType::FLOAT16) { + dtype_str = "fp16"; + } else if (dtype == DataType::BFLOAT16) { + dtype_str = "bf16"; + } + + std::stringstream ss; + if (place.GetType() == phi::AllocationType::GPU) { + ss << "[device=gpu:" << device_id << ", "; + } else { + ss << "[device=cpu, "; + } + ss << "op=" << op_type << ", tensor=" << var_name << ", dtype=" << dtype_str + << "]"; + return ss.str(); +} + +template < + typename T, + std::enable_if_t>::value && + !std::is_same>::value, + bool> = true> +static void CheckNumericsCpuImpl(const T* value_ptr, + const int64_t numel, + const std::string& cpu_hint_str, + const int check_nan_inf_level, + const std::string log_name = "cpu", + const std::string output_dir = "") { + using MT = typename phi::dtype::template MPTypeTrait::Type; + +#ifdef _OPENMP + // Use maximum 4 threads to collect the nan and inf information. 
+ int num_threads = std::max(omp_get_num_threads(), 1); + num_threads = std::min(num_threads, 4); +#else + int num_threads = 1; +#endif + + std::vector thread_num_nan(num_threads, 0); + std::vector thread_num_inf(num_threads, 0); + std::vector thread_num_zero(num_threads, 0); + std::vector thread_min_value(num_threads, static_cast(value_ptr[0])); + std::vector thread_max_value(num_threads, static_cast(value_ptr[0])); + std::vector thread_mean_value(num_threads, static_cast(0)); + +#ifdef _OPENMP +#pragma omp parallel num_threads(num_threads) +#endif + { +#ifdef _OPENMP + int64_t tid = omp_get_thread_num(); + int64_t chunk_size = (numel + num_threads - 1) / num_threads; + int64_t begin = tid * chunk_size; + int64_t end = chunk_size + begin > numel ? numel : chunk_size + begin; +#else + int64_t tid = 0; + int64_t begin = 0; + int64_t end = numel; +#endif + for (int64_t i = begin; i < end; ++i) { + MT value = static_cast(value_ptr[i]); + + thread_min_value[tid] = std::min(thread_min_value[tid], value); + thread_max_value[tid] = std::max(thread_max_value[tid], value); + thread_mean_value[tid] += value / static_cast(numel); + + if (std::isnan(value)) { + thread_num_nan[tid] += 1; + } else if (std::isinf(value)) { + thread_num_inf[tid] += 1; + } + if (value == 0) { + thread_num_zero[tid] += 1; + } + } + } + + int64_t num_nan = 0; + int64_t num_inf = 0; + int64_t num_zero = 0; + MT min_value = thread_min_value[0]; + MT max_value = thread_max_value[0]; + MT mean_value = static_cast(0); + for (int i = 0; i < num_threads; ++i) { + num_nan += thread_num_nan[i]; + num_inf += thread_num_inf[i]; + num_zero += thread_num_zero[i]; + min_value = std::min(thread_min_value[i], min_value); + max_value = std::max(thread_max_value[i], max_value); + mean_value += thread_mean_value[i]; + } + + // Write log to file + if (output_dir.size() > 0) { + WriteToFileForDifferentLevel(cpu_hint_str.c_str(), + numel, + num_nan, + num_inf, + num_zero, + max_value, + min_value, + mean_value, + 
check_nan_inf_level, + log_name, + output_dir); + } else { + PrintForDifferentLevel(cpu_hint_str.c_str(), + numel, + num_nan, + num_inf, + num_zero, + max_value, + min_value, + mean_value, + check_nan_inf_level); + } +} + +template < + typename T, + std::enable_if_t>::value || + std::is_same>::value, + bool> = true> +void CheckNumericsCpuImpl(const T* value_ptr, + const int64_t numel, + const std::string& cpu_hint_str, + const int check_nan_inf_level, + const std::string log_name = "cpu", + const std::string output_dir = "") { + using RealType = typename T::value_type; + + RealType real_sum = 0.0f, imag_sum = 0.0f; + +#ifdef _OPENMP +#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum) +#endif + for (int64_t i = 0; i < numel; ++i) { + T value = value_ptr[i]; + real_sum += (value.real - value.real); + imag_sum += (value.imag - value.imag); + } + + if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || + std::isinf(imag_sum)) { + // hot fix for compile failed in gcc4.8 + // here also need print detail info of nan or inf later + PADDLE_THROW(phi::errors::PreconditionNotMet("There are NAN or INF in %s.", + cpu_hint_str)); + } +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cu b/paddle/phi/kernels/gpu/check_numerics_kernel.cu similarity index 72% rename from paddle/fluid/framework/details/nan_inf_utils_detail.cu rename to paddle/phi/kernels/gpu/check_numerics_kernel.cu index 5569a6f29af..eb9fc6af66c 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cu +++ b/paddle/phi/kernels/gpu/check_numerics_kernel.cu @@ -1,44 +1,41 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/details/nan_inf_utils_detail.h" -#include "paddle/fluid/framework/details/nan_inf_utils.h" - -#include -#include -#include -#include - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/scope.h" +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/check_numerics_kernel.h" + +#include "glog/logging.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/check_numerics_utils.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" PHI_DECLARE_int32(check_nan_inf_level); -namespace paddle { -namespace framework { -namespace details { +namespace phi { static std::once_flag init_multi_gpu_op_var_map_flag; // lazy init -static std::vector>& +static std::vector< + std::unordered_map>& multi_op_var2gpu_str() { - static std::vector> + static std::vector< + std::unordered_map> _multi_op_var2gpu_str; return _multi_op_var2gpu_str; } @@ -49,15 +46,15 @@ static std::vector& multi_op_var2gpu_str_mutex() { } static void InitMultiGPUOpVarMap() { - int dev_count = platform::GetGPUDeviceCount(); + int dev_count = phi::backends::gpu::GetGPUDeviceCount(); PADDLE_ENFORCE_GT(dev_count, 0, - platform::errors::NotFound( + phi::errors::NotFound( "cuda device must > 0, now dev_count=%d", dev_count)); // https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex - std::vector> tmp_multi( - dev_count); + std::vector> + tmp_multi(dev_count); std::vector tmp_multi_mutex(dev_count); multi_op_var2gpu_str().swap(tmp_multi); @@ -297,7 +294,7 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr, int64_t numel, int64_t numel_max_min, int check_nan_inf_level, - int64_t* nan_inf_zero) { + int64_t* nan_inf_zero_ptr) { if (blockIdx.x == 0 && threadIdx.x == 0) { int64_t num_nan = 0; int64_t num_inf = 0; @@ -329,20 +326,21 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr, mean_value += tmp_mean_value; } if (check_nan_inf_level == 0) { - nan_inf_zero[0] = num_nan; - nan_inf_zero[1] = num_inf; 
- nan_inf_zero[2] = num_zero; + nan_inf_zero_ptr[0] = num_nan; + nan_inf_zero_ptr[1] = num_inf; + nan_inf_zero_ptr[2] = num_zero; } } - PrintForDifferentLevel(debug_info, - numel, - num_nan, - num_inf, - num_zero, - max_value, - min_value, - mean_value, - check_nan_inf_level); + + phi::funcs::PrintForDifferentLevel(debug_info, + numel, + num_nan, + num_inf, + num_zero, + max_value, + min_value, + mean_value, + check_nan_inf_level); } } @@ -351,12 +349,13 @@ inline std::string GetHintString(const std::string& op_type, const std::string& var_name, const phi::Place& place, int dev_id = -1) { - std::string op_var = GetCpuHintString(op_type, var_name, place, dev_id); + std::string op_var = + phi::funcs::GetCpuHintString(op_type, var_name, place, dev_id); PADDLE_ENFORCE_EQ( (dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()), true, - platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d", - multi_op_var2gpu_str_mutex().size())); + phi::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d", + multi_op_var2gpu_str_mutex().size())); return op_var; } @@ -375,7 +374,7 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx, std::lock_guard guard(op_var2gpu_str_mutex); if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) { // insert - auto gpu_str_tensor = paddle::memory::Alloc( + auto gpu_str_tensor = phi::memory_utils::Alloc( ctx.GetPlace(), op_var.length() + 1, phi::Stream(reinterpret_cast(ctx.stream()))); @@ -386,7 +385,7 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx, auto iter = op_var2gpu_str.find(op_var); PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true, - platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "op_var=%s should successed insert into " "op_var2gpu_str, but now failed", op_var)); @@ -408,7 +407,7 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx, auto iter = op_var2gpu_str.find(op_var); PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true, - 
platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "op_var=%s should be in the op_var2gpu_str, but " "now can't find it", op_var)); @@ -418,37 +417,41 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx, return gpu_str_ptr; } -template <> -template -void TensorCheckerVisitor::apply( - typename std::enable_if< - std::is_floating_point::value || - std::is_same>::value || - std::is_same>::value>::type*) - const { - auto* dev_ctx = reinterpret_cast( - platform::DeviceContextPool::Instance().Get(tensor.place())); +template +void CheckNumericsKernel(const Context& ctx, + const DenseTensor& tensor, + const std::string& op_type, + const std::string& var_name, + const int stack_height_limit, + const std::string& output_dir) { + std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap); + int dev_id = tensor.place().device; - // Write log to file - auto file_path = GetNanPath(); - if (file_path.size() > 0) { + VLOG(6) << "op_type=" << op_type << ", var_name=" << var_name + << ", dev_id=gpu:" << dev_id + << ", stack_height_limit=" << stack_height_limit + << ", output_dir=" << output_dir; + + // Write log to output_dir. + if (output_dir.size() > 0) { phi::DenseTensor cpu_tensor; - platform::CPUPlace cpu_place; cpu_tensor.Resize(tensor.dims()); - // 1. copy from gpu to cpu - paddle::framework::TensorCopySync(tensor, cpu_place, &cpu_tensor); - auto* dev_ctx = reinterpret_cast( - platform::DeviceContextPool::Instance().Get(tensor.place())); + // Copy tensor from GPU to CPU. + phi::Copy(ctx, tensor, CPUPlace(), true, &cpu_tensor); const std::string debug_info = - GetHintString(op_type, var_name, place, dev_id); - // 2. write log to file - CheckNanInfCpuImpl(cpu_tensor.data(), tensor.numel(), debug_info, "gpu"); + GetHintString(op_type, var_name, tensor.place(), dev_id); + std::string log_name = "gpu." 
+ std::to_string(dev_id); + phi::funcs::CheckNumericsCpuImpl(cpu_tensor.data(), + tensor.numel(), + debug_info, + FLAGS_check_nan_inf_level, + log_name, + output_dir); return; } - // Write log to window - char* gpu_str_ptr = - GetGpuHintStringPtr(*dev_ctx, op_type, var_name, dev_id); + // Print to the standard output. + char* gpu_str_ptr = GetGpuHintStringPtr(ctx, op_type, var_name, dev_id); #ifdef __HIPCC__ // HIP will throw GPU memory access fault if threads > 256 @@ -466,7 +469,7 @@ void TensorCheckerVisitor::apply( dim3(blocks), dim3(threads), 0, - dev_ctx->stream(), + ctx.stream(), tensor.data(), tensor.numel(), print_num, @@ -479,83 +482,75 @@ void TensorCheckerVisitor::apply( phi::DenseTensor block_num_nan_inf_zero; block_num_nan_inf_zero.Resize({static_cast(3 * numel_max_min)}); int64_t* block_num_nan_ptr = - dev_ctx->template Alloc(&block_num_nan_inf_zero); + ctx.template Alloc(&block_num_nan_inf_zero); int64_t* block_num_inf_ptr = block_num_nan_ptr + numel_max_min; int64_t* block_num_zero_ptr = block_num_inf_ptr + numel_max_min; phi::DenseTensor tensor_block_max_min; tensor_block_max_min.Resize({static_cast(3 * numel_max_min)}); - MT* tensor_block_max_ptr = dev_ctx->template Alloc(&tensor_block_max_min); + MT* tensor_block_max_ptr = ctx.template Alloc(&tensor_block_max_min); MT* tensor_block_min_ptr = tensor_block_max_ptr + numel_max_min; MT* tensor_block_mean_ptr = tensor_block_max_ptr + 2 * numel_max_min; FindNanInfAndBlockMaxMin - <<stream()>>>(tensor.data(), - tensor.numel(), - block_num_nan_ptr, - block_num_inf_ptr, - block_num_zero_ptr, - tensor_block_max_ptr, - tensor_block_min_ptr, - tensor_block_mean_ptr); + <<>>(tensor.data(), + tensor.numel(), + block_num_nan_ptr, + block_num_inf_ptr, + block_num_zero_ptr, + tensor_block_max_ptr, + tensor_block_min_ptr, + tensor_block_mean_ptr); int check_nan_inf_level = FLAGS_check_nan_inf_level; + phi::DenseTensor nan_inf_zero_tensor; nan_inf_zero_tensor.Resize({static_cast(3)}); - int64_t* nan_inf_zero = - 
dev_ctx->template Alloc(&nan_inf_zero_tensor); + int64_t* nan_inf_zero_ptr = ctx.template Alloc(&nan_inf_zero_tensor); + FindGlobalMaxMinAndPrint - <<<1, 1, 0, dev_ctx->stream()>>>(block_num_nan_ptr, - block_num_inf_ptr, - block_num_zero_ptr, - tensor_block_max_ptr, - tensor_block_min_ptr, - tensor_block_mean_ptr, - gpu_str_ptr, - tensor.numel(), - numel_max_min, - check_nan_inf_level, - nan_inf_zero_tensor.data()); - - if (check_nan_inf_level == 0 && GetNanInfStackLimit() > 0) { + <<<1, 1, 0, ctx.stream()>>>(block_num_nan_ptr, + block_num_inf_ptr, + block_num_zero_ptr, + tensor_block_max_ptr, + tensor_block_min_ptr, + tensor_block_mean_ptr, + gpu_str_ptr, + tensor.numel(), + numel_max_min, + check_nan_inf_level, + nan_inf_zero_ptr); + + if (check_nan_inf_level == 0 && stack_height_limit > 0) { auto nan_cpu = phi::memory_utils::Alloc(phi::CPUPlace(), sizeof(int64_t) * 3); int64_t* nan_cpu_ptr = reinterpret_cast(nan_cpu->ptr()); phi::memory_utils::Copy(phi::CPUPlace(), nan_cpu_ptr, - place, - nan_inf_zero, + tensor.place(), + nan_inf_zero_ptr, 3 * sizeof(int64_t), - dev_ctx->stream()); - - dev_ctx->Wait(); + ctx.stream()); + ctx.Wait(); if (nan_cpu_ptr[0] > 0 || nan_cpu_ptr[1] > 0) { const std::string debug_info = - GetHintString(op_type, var_name, place, dev_id); - PADDLE_THROW(platform::errors::PreconditionNotMet( - "There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in " - "%s.", - static_cast(nan_cpu_ptr[0]), // NOLINT - static_cast(nan_cpu_ptr[1]), // NOLINT - static_cast(nan_cpu_ptr[2]), // NOLINT - debug_info)); + GetHintString(op_type, var_name, tensor.place(), dev_id); + phi::funcs::PrintAndThrowError( + debug_info.c_str(), nan_cpu_ptr[0], nan_cpu_ptr[1], nan_cpu_ptr[2]); } } #endif } -template <> -void tensor_check(const std::string& op_type, - const std::string& var_name, - const phi::DenseTensor& tensor, - const platform::Place& place) { - std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap); - - TensorCheckerVisitor vistor( 
- op_type, var_name, tensor, place); - VisitDataType(framework::TransToProtoVarType(tensor.dtype()), vistor); -} - -} // namespace details -} // namespace framework -} // namespace paddle +} // namespace phi + +PD_REGISTER_KERNEL(check_numerics, + GPU, + ALL_LAYOUT, + phi::CheckNumericsKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/python/paddle/amp/accuracy_compare.py b/python/paddle/amp/accuracy_compare.py index 85f8f78ac0d..e7c77881375 100644 --- a/python/paddle/amp/accuracy_compare.py +++ b/python/paddle/amp/accuracy_compare.py @@ -33,6 +33,7 @@ def is_allclose(actual, expected, atol=1e-2, rtol=1e-2): class TensorInfo: def __init__(self): + self.device = None self.op_type = None self.tensor_name = None self.dtype = None @@ -45,7 +46,8 @@ class TensorInfo: self.num_zero = None def __str__(self): - return "[TensorInfo] op_type={}, tensor_name={}, dtype={}, numel={}, has_inf={}, has_nan={}, num_zero={}, max_value={:.6f}, min_value={:.6f}, mean_value={:.6f}".format( + return "[TensorInfo] device={}, op_type={}, tensor_name={}, dtype={}, numel={}, num_inf={}, num_nan={}, num_zero={}, max_value={:.6f}, min_value={:.6f}, mean_value={:.6f}".format( + self.device, self.op_type, self.tensor_name, self.dtype, @@ -73,6 +75,8 @@ class TensorInfo: words = word_str.split("=") if words[0] == "op": self.op_type = words[1] + elif words[0] == "device": + self.device = words[1] elif words[0] == "tensor": self.tensor_name = words[1] elif words[0] == "dtype": @@ -380,8 +384,8 @@ class ExcelWriter: "max_value": 16, "min_value": 16, "mean_value": 16, - "has_inf": 8, - "has_nan": 8, + "num_inf": 8, + "num_nan": 8, } title_names = ["op_type", "tensor_name", "numel", "infinite"] if self.log_fp16_dir is None: @@ -412,8 +416,8 @@ class ExcelWriter: "min_value", "mean_value", "num_zero", - "has_inf", - "has_nan", + "num_inf", + "num_nan", ] ) else: @@ -435,8 +439,8 @@ class ExcelWriter: "min_value", 
"mean_value", "num_zero", - "has_inf", - "has_nan", + "num_inf", + "num_nan", "max_value", "min_value", "mean_value", @@ -572,39 +576,45 @@ class ExcelWriter: print(f"-- OP Types produce infinite outputs: {infinite_op_types}") +def parse_lines(lines, specified_op_list=None): + tensor_info_list = [] + + for i in range(len(lines)): + if i % 10 == 0: + print( + f"-- Processing {i:-8d} / {len(lines):-8d} line", + end="\r", + ) + line = lines[i] + if "[PRECISION]" in line: + tensor_info = TensorInfo() + tensor_info.init_from_string(line) + if ( + tensor_info.tensor_name is not None + and tensor_info.tensor_name != "" + ): + has_tensor_name = True + if ( + specified_op_list is None + or tensor_info.op_type in specified_op_list + ): + tensor_info_list.append(tensor_info) + # print(tensor_info) + return tensor_info_list + + def parse_log(log_dir, filename, specified_op_list=None): if log_dir is None or filename is None: return None complete_filename = log_dir + "/" + filename - tensor_info_list = [] + tensor_info_list = None has_tensor_name = False try: with open(complete_filename, 'r') as f: lines = f.readlines() - for i in range(len(lines)): - if i % 10 == 0: - print( - f"-- Processing {i:-8d} / {len(lines):-8d} line", - end="\r", - ) - # [op=adamw] [tensor=encoder_layer_20_multi_head_att_output_fc_0.w_0], numel: 294912, max: 0.005773, min: -0.005774 - line = lines[i] - if "[PRECISION]" in line: - tensor_info = TensorInfo() - tensor_info.init_from_string(line) - if ( - tensor_info.tensor_name is not None - and tensor_info.tensor_name != "" - ): - has_tensor_name = True - if ( - specified_op_list is None - or tensor_info.op_type in specified_op_list - ): - tensor_info_list.append(tensor_info) - # print(tensor_info) + tensor_info_list = parse_lines(lines, specified_op_list) except FileNotFoundError: print("the file ", complete_filename, "is not found") return None, has_tensor_name diff --git a/python/paddle/fluid/tests/unittests/check_nan_inf_base.py 
b/python/paddle/fluid/tests/unittests/check_nan_inf_base.py index f48df68badd..8db773ef27c 100644 --- a/python/paddle/fluid/tests/unittests/check_nan_inf_base.py +++ b/python/paddle/fluid/tests/unittests/check_nan_inf_base.py @@ -17,7 +17,6 @@ import os import numpy as np os.environ["FLAGS_check_nan_inf"] = "1" -os.environ["GLOG_vmodule"] = "nan_inf_utils_detail=10" import paddle from paddle import fluid @@ -59,8 +58,7 @@ def net(): hidden = x - for i in range(2): - hidden = paddle.static.nn.fc(x=hidden, size=400, activation="sigmoid") + hidden = paddle.static.nn.fc(x=hidden, size=400, activation="sigmoid") hidden = paddle.static.nn.fc(x=hidden, size=3) cost, y_predict = paddle.nn.functional.softmax_with_cross_entropy( diff --git a/python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py b/python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py index 0e033de896a..2bb615c76a0 100644 --- a/python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py +++ b/python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py @@ -12,105 +12,114 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os +import argparse import numpy as np -os.environ["FLAGS_check_nan_inf"] = "1" -os.environ["GLOG_vmodule"] = "nan_inf_utils_detail=10" - import paddle from paddle import nn -np.random.seed(0) +# os.environ["GLOG_vmodule"] = "nan_inf_utils_detail=10" -def generator(): - batch_size = 5 - for i in range(5): - curr_train_x = np.random.randint( - batch_size, size=(batch_size, 3) - ).astype("float32") - if i >= 2: - curr_train_x[0, :] = np.nan - curr_train_x[-1, :] = np.inf - res = [] - for i in range(batch_size): - y = i % 3 - res.append([y]) - y_label = np.array(res).astype('int64') - yield [curr_train_x, y_label] +paddle.seed(0) +np.random.seed(0) class TestLayer(nn.Layer): def __init__(self): super().__init__() - self.linear1 = nn.Linear(3, 400) - self.linear2 = nn.Linear(400, 400) - self.linear3 = nn.Linear(400, 3) + w_1_np = np.random.random([32, 400]).astype("float32") + self.linear1 = nn.Linear( + in_features=32, + out_features=400, + weight_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Assign(w_1_np) + ), + ) + w_2_np = np.random.random([400, 10]).astype("float32") + self.linear2 = nn.Linear( + in_features=400, + out_features=10, + weight_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Assign(w_2_np) + ), + ) def forward(self, x): - x = self.linear1(x) - x = nn.functional.sigmoid(x) - x = self.linear2(x) - x = nn.functional.sigmoid(x) - x = self.linear3(x) - x = nn.functional.softmax(x) - - return x + out = self.linear1(x) + out = nn.functional.sigmoid(out) + out = self.linear2(out) + mask = paddle.randint(low=0, high=2, shape=out.shape).astype("float32") + out = paddle.divide(out, mask) + out = nn.functional.softmax(out) + return out -def check(use_cuda): +def check_main(use_cuda, use_amp=False): paddle.set_device('gpu' if use_cuda else 'cpu') - net = TestLayer() - sgd = paddle.optimizer.SGD(learning_rate=0.05, parameters=net.parameters()) - - for step, (x, y) in enumerate(generator()): - x = paddle.to_tensor(x) - y = 
paddle.to_tensor(y) - - zero = paddle.zeros(shape=[1], dtype='int64') - fp16_zero = paddle.cast(zero, dtype='float16') - - y = y + zero - - y_pred = net(x) - - cost = nn.functional.cross_entropy(y_pred, y, use_softmax=False) - avg_cost = paddle.mean(cost) - - acc_top1 = paddle.metric.accuracy(input=y_pred, label=y, k=1) - - print( - 'iter={:.0f}, cost={}, acc1={}'.format( - step, avg_cost.numpy(), acc_top1.numpy() - ) - ) - - sgd.step() - sgd.clear_grad() - - -def run_check(): - if paddle.is_compiled_with_cuda(): - try: - check(use_cuda=True) - raise AssertionError() - except Exception as e: - print(e) - print(type(e)) - # Note. Enforce in cuda kernel may not catch in paddle, and - # Exception type will be RuntimeError - assert type(e) == OSError or type(e) == RuntimeError - try: - check(use_cuda=False) - raise AssertionError() - except Exception as e: - print(e) - print(type(e)) - assert type(e) == RuntimeError + model = TestLayer() + sgd = paddle.optimizer.SGD( + learning_rate=0.05, parameters=model.parameters() + ) + + if use_cuda and use_amp: + scaler = paddle.amp.GradScaler() + + x_np = 10000 * np.random.random([128, 32]).astype("float32") + + x = paddle.to_tensor(x_np) + if use_cuda and use_amp: + with paddle.amp.auto_cast(enable=True, dtype="float16", level="O1"): + out = model(x) + loss = paddle.mean(out) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(sgd, scaled) + else: + out = model(x) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_grad() + + +def run_check(args): + paddle.set_flags( + { + "FLAGS_check_nan_inf": 1, + "FLAGS_check_nan_inf_level": args.check_nan_inf_level, + } + ) + use_cuda = args.use_cuda and paddle.is_compiled_with_cuda() + if args.check_nan_inf_level == 0: + if use_cuda: + try: + check_main(use_cuda=True, use_amp=args.use_amp) + raise AssertionError() + except Exception as e: + print(e) + print(type(e)) + # Note. 
Enforce in cuda kernel may not catch in paddle, and + # Exception type will be RuntimeError + assert type(e) == OSError or type(e) == RuntimeError + else: + try: + check_main(use_cuda=False, use_amp=False) + raise AssertionError() + except Exception as e: + print(e) + print(type(e)) + assert type(e) == RuntimeError + else: + check_main(use_cuda=use_cuda, use_amp=args.use_amp) if __name__ == '__main__': - run_check() + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('--use_cuda', action='store_true', default=False) + parser.add_argument('--use_amp', action='store_true', default=False) + parser.add_argument('--check_nan_inf_level', type=int, default=0) + args = parser.parse_args() + run_check(args) diff --git a/python/paddle/fluid/tests/unittests/test_nan_inf.py b/python/paddle/fluid/tests/unittests/test_nan_inf.py index 851c46c3b89..b44fc4637c0 100644 --- a/python/paddle/fluid/tests/unittests/test_nan_inf.py +++ b/python/paddle/fluid/tests/unittests/test_nan_inf.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import os import subprocess import sys @@ -22,7 +23,7 @@ import numpy as np import paddle -class TestNanInf(unittest.TestCase): +class TestNanInfBase(unittest.TestCase): def setUp(self): self._python_interp = sys.executable if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': @@ -30,9 +31,8 @@ class TestNanInf(unittest.TestCase): self.env = os.environ.copy() - def check_nan_inf(self): - cmd = self._python_interp - + def run_command(self, cmd): + print(f"Run command: {cmd}") proc = subprocess.Popen( cmd.split(" "), stdout=subprocess.PIPE, @@ -42,20 +42,100 @@ class TestNanInf(unittest.TestCase): out, err = proc.communicate() returncode = proc.returncode - # in python3, type(out+err) is 'bytes', need use encode - assert (out + err).find(b'There are NAN or INF') != -1 + return returncode, out, err - def test_nan_inf_in_static_mode(self): - self._python_interp += ( - " " + os.path.dirname(__file__) + "/check_nan_inf_base.py" - ) - self.check_nan_inf() - def test_nan_inf_in_dynamic_mode(self): - self._python_interp += ( - " " + os.path.dirname(__file__) + "/check_nan_inf_base_dygraph.py" - ) - self.check_nan_inf() +class TestNanInf(TestNanInfBase): + def setUp(self): + super().setUp() + self.check_static = True + self.check_dygraph = True + self.check_nan_inf_level = 0 + self.dygraph_expected_op_count = {"divide": 1} + + def check_op_count(self, log, expected_op_count=None): + if expected_op_count is None: + return + + lines = copy.copy(log).decode().split("\n") + actual_op_count = {} + tensor_info_list = paddle.amp.accuracy_compare.parse_lines(lines) + for tensor_info in tensor_info_list: + print(tensor_info) + if actual_op_count.get(tensor_info.op_type, None) is None: + actual_op_count[tensor_info.op_type] = 1 + else: + actual_op_count[tensor_info.op_type] += 1 + print(actual_op_count) + + for op_type, expected_value in expected_op_count.items(): + actual_value = actual_op_count.get(op_type, 0) + self.assertEqual( + actual_value, + expected_value, + f"The number 
of operator < {op_type} > is expected to be {expected_value}, but recieved {actual_value}.", + ) + print("") + + def run_check_nan_inf(self, cmd, expected_op_count=None): + returncode, out, err = self.run_command(cmd) + self.check_op_count(out, expected_op_count) + if self.check_nan_inf_level == 0: + # in python3, type(out+err) is 'bytes', need use encode + self.assertNotEqual( + (out + err).find(b'There are NAN or INF'), + -1, + f"Cannot find NAN / INF keyword in:\n{out + err}", + ) + + def test_nan_inf_static(self): + if not self.check_static: + return + + filepath = os.path.dirname(__file__) + "/check_nan_inf_base.py" + cmd = f"{self._python_interp} {filepath}" + self.run_check_nan_inf(cmd, None) + + def test_nan_inf_dynamic(self): + if not self.check_dygraph: + return + + filepath = os.path.dirname(__file__) + "/check_nan_inf_base_dygraph.py" + + # Test on CPU. + cmd = f"{self._python_interp} {filepath} --check_nan_inf_level {self.check_nan_inf_level}" + self.run_check_nan_inf(cmd, self.dygraph_expected_op_count) + + # Test on GPU. 
+ if paddle.fluid.core.is_compiled_with_cuda(): + cmd = f"{self._python_interp} {filepath} --use_cuda --check_nan_inf_level {self.check_nan_inf_level}" + self.run_check_nan_inf(cmd, self.dygraph_expected_op_count) + + +class TestCheckAll(TestNanInf): + def setUp(self): + super().setUp() + self.check_static = False + self.check_dygraph = True + self.check_nan_inf_level = 3 + self.dygraph_expected_op_count = { + 'assign_value_': 2, + 'full_': 3, + 'matmul': 2, + 'add': 2, + 'sigmoid': 1, + 'cast': 1, + 'divide': 1, + 'softmax': 1, + 'mean': 1, + 'mean_grad': 1, + 'softmax_grad': 1, + 'divide_grad': 1, + 'add_grad': 4, + 'matmul_grad': 3, + 'sigmoid_grad': 1, + 'sgd_': 4, + } class TestNanInfEnv(TestNanInf): @@ -67,24 +147,31 @@ class TestNanInfEnv(TestNanInf): self.env["PADDLE_INF_NAN_SKIP_ROLE"] = "loss" self.env["PADDLE_INF_NAN_SKIP_VAR"] = "elementwise_add:fc_0.tmp_1" + self.check_static = True + self.check_dygraph = False + self.check_nan_inf_level = 0 + self.dygraph_expected_op_count = None -class TestCheckSkipEnv(TestNanInf): - def setUp(self): - super().setUp() - # windows python have some bug with env, so need use str to pass ci - # otherwise, "TypeError: environment can only contain strings" - self.env["Paddle_check_nan_inf_op_list"] = "mean" - self.env["Paddle_skip_nan_inf_op_list"] = "elementwise_add" +class TestNanInfStack(TestNanInfBase): + def check_stack(self, file_name): + cmd = self._python_interp + file_name + returncode, out, err = self.run_command(cmd) -class TestNanInfCheckResult(unittest.TestCase): - def setUp(self): - self._python_interp = sys.executable - if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': - self._python_interp += " -m coverage run --branch -p" + print(out) + print(err) - self.env = os.environ.copy() + # in python3, type(out+err) is 'bytes', need use encode + assert (out + err).find(b' z = paddle.pow(x, y)') != -1 + def test_check_stack(self): + self.check_stack(" check_nan_inf_backward_stack.py") + + def 
test_statck_check_stack(self): + self.check_stack(" check_nan_inf_backward_static_stack.py") + + +class TestNanInfCheckResult(TestNanInfBase): def generate_inputs(self, shape, dtype="float32"): data = np.random.random(size=shape).astype(dtype) # [-10, 10) @@ -148,32 +235,11 @@ class TestNanInfCheckResult(unittest.TestCase): if paddle.fluid.core.is_compiled_with_cuda(): _check_num_nan_inf(use_cuda=True) - def check_stack(self, file_name): - self._python_interp += file_name - cmd = self._python_interp - proc = subprocess.Popen( - cmd.split(" "), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=self.env, + def run_check_nan_inf_level(self, use_cuda, dtype, level): + paddle.set_flags( + {"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": level} ) - out, err = proc.communicate() - returncode = proc.returncode - - print(out) - print(err) - - # in python3, type(out+err) is 'bytes', need use encode - assert (out + err).find(b' z = paddle.pow(x, y)') != -1 - - def test_check_stack(self): - self.check_stack(" check_nan_inf_backward_stack.py") - - def test_statck_check_stack(self): - self.check_stack(" check_nan_inf_backward_static_stack.py") - - def check_nan_inf_level(self, use_cuda, dtype): shape = [8, 8] x_np, y_np = self.generate_inputs(shape, dtype) @@ -186,33 +252,36 @@ class TestNanInfCheckResult(unittest.TestCase): out = paddle.log(x * 1e6) / y def test_check_nan_inf_level_float32(self): - paddle.set_flags( - {"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 2} + level = 2 + self.run_check_nan_inf_level( + use_cuda=False, dtype="float32", level=level ) - self.check_nan_inf_level(use_cuda=False, dtype="float32") if paddle.fluid.core.is_compiled_with_cuda(): - self.check_nan_inf_level(use_cuda=True, dtype="float32") + self.run_check_nan_inf_level( + use_cuda=True, dtype="float32", level=level + ) def test_check_nan_inf_level_float16(self): - paddle.set_flags( - {"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3} + level = 3 + 
self.run_check_nan_inf_level( + use_cuda=False, dtype="float32", level=level ) if paddle.fluid.core.is_compiled_with_cuda(): - self.check_nan_inf_level(use_cuda=True, dtype="float16") + self.run_check_nan_inf_level( + use_cuda=True, dtype="float16", level=level + ) def test_check_numerics(self): paddle.set_flags( {"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3} ) - if paddle.fluid.core.is_compiled_with_cuda(): - self.check_nan_inf_level(use_cuda=True, dtype="float16") shape = [8, 8] x_np, y_np = self.generate_inputs(shape, "float16") x = paddle.to_tensor(x_np) y = paddle.to_tensor(y_np) - paddle.fluid.core.check_numerics("check_numerics", x) - paddle.fluid.core.check_numerics("check_numerics", y) + paddle.fluid.core.check_numerics("check_tensor", x) + paddle.fluid.core.check_numerics("check_tensor", y) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_nan_inf_dir.py b/python/paddle/fluid/tests/unittests/test_nan_inf_dir.py index 06695c56f24..122ddb74f41 100644 --- a/python/paddle/fluid/tests/unittests/test_nan_inf_dir.py +++ b/python/paddle/fluid/tests/unittests/test_nan_inf_dir.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os -import sys +import tempfile import unittest import numpy as np @@ -22,97 +22,99 @@ import paddle class TestNanInfDirCheckResult(unittest.TestCase): - def generate_inputs(self, shape, dtype="float32"): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def generate_inputs(self, shape, low=0, high=1, dtype="float32"): data = np.random.random(size=shape).astype(dtype) - # [-10, 10) - x = (data * 20 - 10) * np.random.randint( + x = (data * (high - low) + low) * np.random.randint( low=0, high=2, size=shape ).astype(dtype) - y = np.random.randint(low=0, high=2, size=shape).astype(dtype) - return x, y + return x def get_reference_num_nan_inf(self, x): out = np.log(x) num_nan = np.sum(np.isnan(out)) num_inf = np.sum(np.isinf(out)) - print(f"[reference] num_nan={num_nan}, num_inf={num_inf}") + print(f"-- [reference] num_nan={num_nan}, num_inf={num_inf}") return num_nan, num_inf - def get_num_nan_inf( - self, x_np, use_cuda=True, add_assert=False, pt="nan_inf_log_dir" - ): + def get_num_nan_inf(self, x_np, use_cuda=True, output_dir=None): + if use_cuda: + paddle.device.set_device("gpu:0") + else: + paddle.device.set_device("cpu") + x = paddle.to_tensor(x_np) + x = x * 0.5 + out = paddle.log(x) + if use_cuda: + paddle.device.cuda.synchronize() + + self.assertEqual( + os.path.exists(output_dir) and os.path.isdir(output_dir), True + ) + num_nan = 0 num_inf = 0 - if add_assert: - if use_cuda: - paddle.device.set_device("gpu:0") - else: - paddle.device.set_device("cpu") - x = paddle.to_tensor(x_np) - out = paddle.log(x) - sys.stdout.flush() - if not use_cuda: - os.path.exists(pt) - num_nan = 0 - num_inf = 0 - for root, dirs, files in os.walk(pt): - for file_name in files: - if file_name.startswith('worker_cpu'): - file_path = os.path.join(root, file_name) - with open(file_path, "rb") as fp: - for e in fp: - err_str_list = ( - str(e) - .replace("(", " ") - .replace(")", " ") - .replace(",", " ") - 
.split(" ") - ) - for err_str in err_str_list: - if "num_nan" in err_str: - num_nan = int(err_str.split("=")[1]) - elif "num_inf" in err_str: - num_inf = int(err_str.split("=")[1]) - print(f"[paddle] num_nan={num_nan}, num_inf={num_inf}") + prefix = "worker_gpu" if use_cuda else "worker_cpu" + for filename in os.listdir(output_dir): + if filename.startswith(prefix): + filepath = os.path.join(output_dir, filename) + print(f"-- Parse {filepath}") + with open(filepath, "rb") as fp: + for e in fp: + err_str_list = ( + str(e) + .replace("(", " ") + .replace(")", " ") + .replace(",", " ") + .split(" ") + ) + for err_str in err_str_list: + if "num_nan" in err_str: + num_nan = int(err_str.split("=")[1]) + elif "num_inf" in err_str: + num_inf = int(err_str.split("=")[1]) + print( + f"-- [paddle] use_cuda={use_cuda}, num_nan={num_nan}, num_inf={num_inf}" + ) return num_nan, num_inf - def test_num_nan_inf(self): - path = "nan_inf_log_dir" - + def check_num_nan_inf(self, x_np, use_cuda, subdir): + output_dir = self.temp_dir.name + "/" + subdir + print(f"-- output_dir: {output_dir}") checker_config = paddle.amp.debugging.TensorCheckerConfig( enable=True, debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL, - output_dir=path, + output_dir=output_dir, ) - paddle.amp.debugging.enable_tensor_checker(checker_config) - def _check_num_nan_inf(use_cuda): - shape = [32, 32] - x_np, _ = self.generate_inputs(shape) - num_nan_np, num_inf_np = self.get_reference_num_nan_inf(x_np) - add_assert = (num_nan_np + num_inf_np) > 0 - num_nan, num_inf = self.get_num_nan_inf( - x_np, - use_cuda, - add_assert, - path, - ) - if not use_cuda: - assert num_nan == num_nan_np and num_inf == num_inf_np - - if paddle.fluid.core.is_compiled_with_cuda(): - _check_num_nan_inf(use_cuda=True) - else: - _check_num_nan_inf(use_cuda=False) + num_nan_np, num_inf_np = self.get_reference_num_nan_inf(x_np) + num_nan, num_inf = self.get_num_nan_inf( + x_np, + use_cuda, + output_dir, + ) + self.assertEqual(num_nan, 
num_nan_np) + self.assertEqual(num_inf, num_inf_np) - x = paddle.to_tensor([2, 3, 4], 'float32') - y = paddle.to_tensor([1, 5, 2], 'float32') - z = paddle.add(x, y) - path = "" - paddle.fluid.core.set_nan_inf_debug_path(path) paddle.amp.debugging.disable_tensor_checker() + def test_num_nan_inf(self): + shape = [32, 32] + x_np = self.generate_inputs(shape, -10, 10) + self.check_num_nan_inf( + x_np, use_cuda=False, subdir="check_nan_inf_dir_cpu" + ) + if paddle.fluid.core.is_compiled_with_cuda(): + self.check_num_nan_inf( + x_np, use_cuda=True, subdir="check_nan_inf_dir_gpu" + ) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_checker.py b/test/amp/test_tensor_checker.py similarity index 73% rename from python/paddle/fluid/tests/unittests/test_tensor_checker.py rename to test/amp/test_tensor_checker.py index a5b5e82034f..d495a43c37a 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_checker.py +++ b/test/amp/test_tensor_checker.py @@ -18,7 +18,7 @@ import paddle class TestTensorChecker(unittest.TestCase): - def get_num_inf(self, e): + def _parse_num_nan_inf(self, e): num_nan = 0 num_inf = 0 # Cannot catch the log in CUDA kernel. 
@@ -34,14 +34,9 @@ class TestTensorChecker(unittest.TestCase): num_nan = int(err_str.split("=")[1]) elif "num_inf" in err_str: num_inf = int(err_str.split("=")[1]) - print( - "[CHECK_NAN_INF_AND_ABORT] num_nan={}, num_inf={}".format( - num_nan, num_inf - ) - ) - return num_nan + return num_nan, num_inf - def generate_num_inf(self, place): + def _generate_num_inf(self, place): num_inf = 0 num_nan = 0 paddle.set_device(place) @@ -58,8 +53,8 @@ class TestTensorChecker(unittest.TestCase): paddle.autograd.backward([res]) res = paddle.divide(y, x) except Exception as e: - num_inf = self.get_num_inf(e) - return num_inf + num_nan, num_inf = self._parse_num_nan_inf(e) + return num_nan, num_inf def test_tensor_checker(self): def _assert_flag(value): @@ -86,22 +81,38 @@ class TestTensorChecker(unittest.TestCase): for place in places: paddle.amp.debugging.TensorCheckerConfig.current_step_id = 0 - for index in range(5): + for iter_id in range(5): paddle.amp.debugging.enable_tensor_checker(checker_config) - if index <= 2: + if iter_id <= 2: _assert_flag(True) self.assertEqual( - index + 1, + iter_id + 1, paddle.amp.debugging.TensorCheckerConfig.current_step_id, ) - self.assertEqual(1, self.generate_num_inf(place)) + num_nan, num_inf = self._generate_num_inf(place) + print( + f"-- [iter_id={iter_id}, place={place}] num_nan={num_nan}, num_inf={num_inf}" + ) + self.assertEqual( + 1, + num_nan, + f"Expected num_nan to be 1, but recieved {num_nan}, place={place}.", + ) else: self.assertEqual( 3, paddle.amp.debugging.TensorCheckerConfig.current_step_id, ) _assert_flag(False) - self.assertEqual(0, self.generate_num_inf(place)) + num_nan, num_inf = self._generate_num_inf(place) + print( + f"-- [iter_id={iter_id}, place={place}] num_nan={num_nan}, num_inf={num_inf}" + ) + self.assertEqual( + 0, + num_nan, + f"Expected num_nan to be 1, but recieved {num_nan}, place={place}.", + ) paddle.amp.debugging.disable_tensor_checker() _assert_flag(False) diff --git 
a/test/cpp/eager/task_tests/CMakeLists.txt b/test/cpp/eager/task_tests/CMakeLists.txt index de963c86203..4df64e81d0f 100755 --- a/test/cpp/eager/task_tests/CMakeLists.txt +++ b/test/cpp/eager/task_tests/CMakeLists.txt @@ -1,7 +1,7 @@ cc_test( test_egr_task_nan_inf_utils SRCS nan_inf_utils_test.cc - DEPS eager_nan_inf_utils) + DEPS eager_nan_inf_utils phi) if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) cc_test( -- GitLab