Unverified commit c18fddd3, authored by niuliling123, committed by GitHub

Save NaN log to file when output_dir is set (#49200)

Parent 0e51f398
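
This change lets the NaN/Inf checker append per-tensor statistics to files under a user-specified directory instead of printing them to the terminal. A minimal usage sketch, mirroring the Python test added at the end of this diff (the worker_<cpu|gpu>.<dev_id> file names follow PrintForDifferentLevelFile below):

import numpy as np
import paddle

# Route the checker's report to a directory instead of stdout.
paddle.fluid.core.set_nan_inf_debug_path("nan_inf_log_dir")
paddle.set_flags({"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3})

x = paddle.to_tensor(np.array([2.0, 0.0, -1.0], dtype="float32"))
out = paddle.log(x)  # log(0) -> -inf, log(-1) -> nan

# The op's NaN/Inf statistics are appended to nan_inf_log_dir/worker_cpu.0
# (or worker_gpu.<dev_id> on GPU) rather than printed to the window.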
@@ -30,6 +30,23 @@ DECLARE_int32(check_nan_inf_level);
namespace paddle {
namespace framework {
namespace details {
struct DebugTools {
DebugTools() {}
std::string path = "";
};
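// Process-wide debug settings; an empty path keeps the default
// print-to-terminal behavior.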
static DebugTools debug_nan_inf;
void SetNanInfDebugPath(const std::string& nan_inf_path) {
debug_nan_inf.path = nan_inf_path;
VLOG(4) << "Set the log's path of debug tools : " << nan_inf_path;
}
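// Returns the configured directory with a trailing '/' so file names can be
// appended directly; an empty result means file logging is disabled.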
std::string GetNanPath() {
if (debug_nan_inf.path.empty()) {
return "";
}
return debug_nan_inf.path + "/";
}
static std::once_flag white_list_init_flag;
@@ -134,112 +151,6 @@ static void InitWhiteListFormEnv() {
}
}
template <
typename T,
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
!std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
static void CheckNanInfCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str) {
using MT = typename phi::dtype::template MPTypeTrait<T>::Type;
#ifdef _OPENMP
// Use at most 4 threads to collect the NaN and Inf information.
int num_threads = std::max(omp_get_max_threads(), 1);
num_threads = std::min(num_threads, 4);
#else
int num_threads = 1;
#endif
std::vector<int64_t> thread_num_nan(num_threads, 0);
std::vector<int64_t> thread_num_inf(num_threads, 0);
std::vector<MT> thread_min_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_max_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_mean_value(num_threads, static_cast<MT>(0));
#ifdef _OPENMP
#pragma omp parallel num_threads(num_threads)
#endif
{
#ifdef _OPENMP
int64_t tid = omp_get_thread_num();
int64_t chunk_size = (numel + num_threads - 1) / num_threads;
int64_t begin = tid * chunk_size;
int64_t end = chunk_size + begin > numel ? numel : chunk_size + begin;
#else
int64_t tid = 0;
int64_t begin = 0;
int64_t end = numel;
#endif
for (int64_t i = begin; i < end; ++i) {
MT value = static_cast<MT>(value_ptr[i]);
thread_min_value[tid] = std::min(thread_min_value[tid], value);
thread_max_value[tid] = std::max(thread_max_value[tid], value);
thread_mean_value[tid] += value / static_cast<MT>(numel);
if (std::isnan(value)) {
thread_num_nan[tid] += 1;
} else if (std::isinf(value)) {
thread_num_inf[tid] += 1;
}
}
}
int64_t num_nan = 0;
int64_t num_inf = 0;
MT min_value = thread_min_value[0];
MT max_value = thread_max_value[0];
MT mean_value = static_cast<MT>(0);
for (int i = 0; i < num_threads; ++i) {
num_nan += thread_num_nan[i];
num_inf += thread_num_inf[i];
min_value = std::min(thread_min_value[i], min_value);
max_value = std::max(thread_max_value[i], max_value);
mean_value += thread_mean_value[i];
}
PrintForDifferentLevel<T, MT>(cpu_hint_str.c_str(),
numel,
num_nan,
num_inf,
max_value,
min_value,
mean_value,
FLAGS_check_nan_inf_level);
}
template <
typename T,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
void CheckNanInfCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str) {
using RealType = typename T::value_type;
RealType real_sum = 0.0f, imag_sum = 0.0f;
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum)
#endif
for (int64_t i = 0; i < numel; ++i) {
T value = value_ptr[i];
real_sum += (value.real - value.real);
imag_sum += (value.imag - value.imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// Hot fix for a compile failure on gcc 4.8; detailed NaN/Inf information
// should be printed here as well.
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are NAN or INF in %s.", cpu_hint_str));
}
}
template <>
template <typename T>
void TensorCheckerVisitor<phi::CPUContext>::apply(
......
@@ -322,18 +322,26 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
}
template <typename T>
inline std::string GetHintString(const std::string& op_type,
const std::string& var_name,
const phi::Place& place,
int dev_id = -1) {
std::string op_var = GetCpuHintString<T>(op_type, var_name, place, dev_id);
PADDLE_ENFORCE_EQ(
(dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()),
true,
platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d",
multi_op_var2gpu_str_mutex().size()));
return op_var;
}
template <typename T>
static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
const std::string& op_type,
const std::string& var_name,
int dev_id) {
std::string op_var =
GetHintString<T>(op_type, var_name, ctx.GetPlace(), dev_id);
char* gpu_str_ptr = nullptr;
{
@@ -396,6 +404,24 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(tensor.place()));
int dev_id = tensor.place().device;
// Write log to file
auto file_path = GetNanPath();
if (file_path.size() > 0) {
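// Copy the tensor to the host and reuse the CPU checker, which can append
// its statistics to a file.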
phi::DenseTensor cpu_tensor;
platform::CPUPlace cpu_place;
cpu_tensor.Resize(tensor.dims());
// 1. copy from gpu to cpu
paddle::framework::TensorCopySync(tensor, cpu_place, &cpu_tensor);
const std::string debug_info =
GetHintString<T>(op_type, var_name, place, dev_id);
// 2. write log to file
CheckNanInfCpuImpl(cpu_tensor.data<T>(), tensor.numel(), debug_info, "gpu");
return;
}
// Write log to window
char* gpu_str_ptr =
GetGpuHintStringPtr<T>(*dev_ctx, op_type, var_name, dev_id);
......
@@ -13,17 +13,33 @@
// limitations under the License.
#pragma once
#include <fstream>
#include <iostream>
#include <string>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#define MKDIR(path) _mkdir(path)
#else
#include <sys/stat.h>
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif
DECLARE_int32(check_nan_inf_level);
namespace paddle {
namespace framework {
namespace details {
void SetNanInfDebugPath(const std::string& nan_inf_path);
std::string GetNanPath();
template <typename T,
typename MT,
std::enable_if_t<std::is_same<T, float>::value, bool> = true>
@@ -93,6 +109,49 @@ HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
}
}
template <typename T, typename MT>
void PrintForDifferentLevelFile(const char* debug_info,
int64_t numel,
int64_t num_nan,
int64_t num_inf,
MT max_value,
MT min_value,
MT mean_value,
int check_nan_inf_level,
const std::string& log_name) {
int dev_id = 0;
#ifdef PADDLE_WITH_HIP
hipGetDevice(&dev_id);
#elif defined(PADDLE_WITH_CUDA)
cudaGetDevice(&dev_id);
#endif
auto file_path = GetNanPath();
MKDIR(file_path.c_str());
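// One log file per worker role and device, e.g. "worker_gpu.0"; entries are
// appended, so multiple ops share the same file.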
std::string file_name = "worker_" + log_name + "." + std::to_string(dev_id);
std::string path = file_path + file_name;
std::ofstream outfile(path, std::ios::app);
if (!outfile.is_open()) {
return;
}
if (num_nan > 0 || num_inf > 0) {
outfile << "[PRECISION] [ERROR] in " << debug_info
<< ", numel=" << static_cast<long long>(numel) // NOLINT
<< ", num_nan=" << static_cast<long long>(num_nan) // NOLINT
<< ", num_inf=" << static_cast<long long>(num_inf) // NOLINT
<< ", max=" << static_cast<float>(max_value)
<< ", min=" << static_cast<float>(min_value)
<< ", mean=" << static_cast<float>(mean_value) << std::endl;
} else if (NeedPrint<T, MT>(max_value, min_value, check_nan_inf_level)) {
outfile << "[PRECISION] in " << debug_info
<< ", numel=" << static_cast<long long>(numel) // NOLINT
<< ", max=" << static_cast<float>(max_value)
<< ", min=" << static_cast<float>(min_value)
<< ", mean=" << static_cast<float>(mean_value) << std::endl;
}
outfile.close();
}
template <typename T>
inline std::string GetCpuHintString(const std::string& op_type,
const std::string& var_name,
@@ -120,6 +179,130 @@ inline std::string GetCpuHintString(const std::string& op_type,
return ss.str();
}
template <
typename T,
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
!std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
static void CheckNanInfCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const std::string log_name = "cpu") {
using MT = typename phi::dtype::template MPTypeTrait<T>::Type;
#ifdef _OPENMP
// Use at most 4 threads to collect the NaN and Inf information.
int num_threads = std::max(omp_get_max_threads(), 1);
num_threads = std::min(num_threads, 4);
#else
int num_threads = 1;
#endif
std::vector<int64_t> thread_num_nan(num_threads, 0);
std::vector<int64_t> thread_num_inf(num_threads, 0);
std::vector<MT> thread_min_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_max_value(num_threads, static_cast<MT>(value_ptr[0]));
std::vector<MT> thread_mean_value(num_threads, static_cast<MT>(0));
#ifdef _OPENMP
#pragma omp parallel num_threads(num_threads)
#endif
{
#ifdef _OPENMP
int64_t tid = omp_get_thread_num();
int64_t chunk_size = (numel + num_threads - 1) / num_threads;
int64_t begin = tid * chunk_size;
int64_t end = chunk_size + begin > numel ? numel : chunk_size + begin;
#else
int64_t tid = 0;
int64_t begin = 0;
int64_t end = numel;
#endif
for (int64_t i = begin; i < end; ++i) {
MT value = static_cast<MT>(value_ptr[i]);
thread_min_value[tid] = std::min(thread_min_value[tid], value);
thread_max_value[tid] = std::max(thread_max_value[tid], value);
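// Accumulate value / numel per element to keep the running mean from
// overflowing on large tensors.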
thread_mean_value[tid] += value / static_cast<MT>(numel);
if (std::isnan(value)) {
thread_num_nan[tid] += 1;
} else if (std::isinf(value)) {
thread_num_inf[tid] += 1;
}
}
}
int64_t num_nan = 0;
int64_t num_inf = 0;
MT min_value = thread_min_value[0];
MT max_value = thread_max_value[0];
MT mean_value = static_cast<MT>(0);
for (int i = 0; i < num_threads; ++i) {
num_nan += thread_num_nan[i];
num_inf += thread_num_inf[i];
min_value = std::min(thread_min_value[i], min_value);
max_value = std::max(thread_max_value[i], max_value);
mean_value += thread_mean_value[i];
}
auto file_path = GetNanPath();
// Write log to file
if (file_path.size() > 0) {
VLOG(4) << "[FLAGS_check_nan_inf_level=" << FLAGS_check_nan_inf_level
<< "]. Write log to " << file_path;
PrintForDifferentLevelFile<T, MT>(cpu_hint_str.c_str(),
numel,
num_nan,
num_inf,
max_value,
min_value,
mean_value,
FLAGS_check_nan_inf_level,
log_name);
return;
}
PrintForDifferentLevel<T, MT>(cpu_hint_str.c_str(),
numel,
num_nan,
num_inf,
max_value,
min_value,
mean_value,
FLAGS_check_nan_inf_level);
}
template <
typename T,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
void CheckNanInfCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const std::string log_name = "cpu") {
using RealType = typename T::value_type;
RealType real_sum = 0.0f, imag_sum = 0.0f;
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum)
#endif
for (int64_t i = 0; i < numel; ++i) {
T value = value_ptr[i];
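// x - x is 0 for any finite x but NaN for NaN or Inf inputs, so the sums
// are non-finite if and only if some element is.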
real_sum += (value.real - value.real);
imag_sum += (value.imag - value.imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// Hot fix for a compile failure on gcc 4.8; detailed NaN/Inf information
// should be printed here as well.
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are NAN or INF in %s.", cpu_hint_str));
}
}
template <typename DeviceContext>
struct TensorCheckerVisitor {
TensorCheckerVisitor(const std::string& o,
......
@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
@@ -2671,6 +2672,9 @@ All parameter, weight, gradient are variables in Paddle.
m.def("use_layout_autotune",
[] { return egr::Controller::Instance().UseLayoutAutoTune(); });
// Add the API for NaN op debugging
m.def("set_nan_inf_debug_path",
&paddle::framework::details::SetNanInfDebugPath);
BindFleetWrapper(&m);
BindIO(&m);
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
import numpy as np
import paddle
class TestNanInfDirCheckResult(unittest.TestCase):
    def generate_inputs(self, shape, dtype="float32"):
        data = np.random.random(size=shape).astype(dtype)
        # x lies in [-10, 10), with elements randomly zeroed out.
        x = (data * 20 - 10) * np.random.randint(
            low=0, high=2, size=shape
        ).astype(dtype)
        y = np.random.randint(low=0, high=2, size=shape).astype(dtype)
        return x, y
    def get_reference_num_nan_inf(self, x):
        out = np.log(x)
        num_nan = np.sum(np.isnan(out))
        num_inf = np.sum(np.isinf(out))
        print("[reference] num_nan={}, num_inf={}".format(num_nan, num_inf))
        return num_nan, num_inf
    def get_num_nan_inf(
        self, x_np, use_cuda=True, add_assert=False, pt="nan_inf_log_dir"
    ):
        num_nan = 0
        num_inf = 0
        if add_assert:
            if use_cuda:
                paddle.device.set_device("gpu:0")
            else:
                paddle.device.set_device("cpu")
            x = paddle.to_tensor(x_np)
            out = paddle.log(x)
            sys.stdout.flush()
            if not use_cuda:
                assert os.path.exists(pt)
                num_nan = 0
                num_inf = 0
                # Parse num_nan/num_inf back out of the worker_cpu.* log files.
                for root, dirs, files in os.walk(pt):
                    for file_name in files:
                        if file_name.startswith('worker_cpu'):
                            file_path = os.path.join(root, file_name)
                            with open(file_path, "rb") as fp:
                                for e in fp:
                                    err_str_list = (
                                        str(e)
                                        .replace("(", " ")
                                        .replace(")", " ")
                                        .replace(",", " ")
                                        .split(" ")
                                    )
                                    for err_str in err_str_list:
                                        if "num_nan" in err_str:
                                            num_nan = int(err_str.split("=")[1])
                                        elif "num_inf" in err_str:
                                            num_inf = int(err_str.split("=")[1])
                print(
                    "[paddle] num_nan={}, num_inf={}".format(num_nan, num_inf)
                )
        return num_nan, num_inf
    def test_num_nan_inf(self):
        path = "nan_inf_log_dir"
        paddle.fluid.core.set_nan_inf_debug_path(path)

        def _check_num_nan_inf(use_cuda):
            shape = [32, 32]
            x_np, _ = self.generate_inputs(shape)
            num_nan_np, num_inf_np = self.get_reference_num_nan_inf(x_np)
            add_assert = (num_nan_np + num_inf_np) > 0
            num_nan, num_inf = self.get_num_nan_inf(
                x_np, use_cuda, add_assert, path
            )
            if not use_cuda:
                assert num_nan == num_nan_np and num_inf == num_inf_np

        paddle.set_flags(
            {"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3}
        )
        _check_num_nan_inf(use_cuda=False)
        if paddle.fluid.core.is_compiled_with_cuda():
            _check_num_nan_inf(use_cuda=True)

        x = paddle.to_tensor([2, 3, 4], 'float32')
        y = paddle.to_tensor([1, 5, 2], 'float32')
        z = paddle.add(x, y)
if __name__ == '__main__':
    unittest.main()