Unverified commit a5444592, authored by Yiqun Liu, committed by GitHub

[AMP] Add check_numerics API. (#54301)

* Add outputs to check_numerics_kernel.

* Add check_numerics to yaml.

* Add API and unittest.

* Add check_nan_inf_level as argument of check_numerics_kernel.

* Add more unittests.

* Fix static API implementation and unittest.

* Move the implementation of check_numerics to paddle.amp.

* Fix import error.
Parent 418d2796
......@@ -23,8 +23,6 @@
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
PHI_DECLARE_int32(check_nan_inf_level);
namespace paddle {
namespace framework {
namespace details {
......
......@@ -19,9 +19,12 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/check_numerics_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
PHI_DECLARE_int32(check_nan_inf_level);
namespace paddle {
namespace framework {
namespace details {
......@@ -58,9 +61,18 @@ struct TensorCheckerVisitor {
auto* dev_ctx = reinterpret_cast<Context*>(
platform::DeviceContextPool::Instance().Get(tensor.place()));
phi::DenseTensor stats;
phi::DenseTensor values;
auto file_path = GetNanPath();
phi::CheckNumericsKernel<T, Context>(
*dev_ctx, tensor, op_type, var_name, GetNanInfStackLimit(), file_path);
phi::CheckNumericsKernel<T, Context>(*dev_ctx,
tensor,
op_type,
var_name,
FLAGS_check_nan_inf_level,
GetNanInfStackLimit(),
file_path,
&stats,
&values);
}
std::string op_type;
......
......@@ -2699,12 +2699,6 @@ All parameter, weight, gradient are variables in Paddle.
m.def("set_skipped_op_list",
[](const std::string &op_list) { egr::SetSkipOpList(op_list); });
m.def("check_numerics",
[](const std::string &op_name, const paddle::Tensor &tensor) {
VLOG(4) << "Check tensor whether has nan or inf.";
egr::CheckTensorHasNanOrInf(op_name, tensor);
});
BindFleetWrapper(&m);
BindIO(&m);
BindParallelExecutor(m);
......
......@@ -390,6 +390,14 @@
data_type : x
inplace : (x -> out)
- op : check_numerics
args : (Tensor tensor, str op_type = "", str var_name = "", int check_nan_inf_level = 0, int stack_height_limit = -1, str output_dir = "")
output : Tensor(stats), Tensor(values)
infer_meta :
func : CheckNumericsInferMeta
kernel :
func : check_numerics
- op : cholesky
args : (Tensor x, bool upper=false)
output : Tensor
......
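From this yaml entry the op generator derives an eager `_C_ops.check_numerics` binding whose positional order matches the `args` string above (the same call appears verbatim in the `python/paddle/amp/debugging.py` hunk further down). A minimal sketch, assuming a build that contains this commit; the input values are illustrative only, and level 3 is assumed to correspond to `DebugMode.CHECK_ALL` (the unittests set `FLAGS_check_nan_inf_level` to 3 to print rather than abort):

```python
import paddle
from paddle import _C_ops

x = paddle.to_tensor([1.0, 0.0, 3.0], dtype="float32")
# args order per the yaml entry: tensor, op_type, var_name,
# check_nan_inf_level, stack_height_limit, output_dir.
stats, values = _C_ops.check_numerics(x, "to_tensor", "x", 3, -1, "")
print(stats)   # int64 tensor of shape [3]: [num_nan, num_inf, num_zero]
print(values)  # float32 tensor of shape [3]: [max, min, mean]
```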
......@@ -4959,6 +4959,20 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
out->set_dims(output_dims);
}
void CheckNumericsInferMeta(const MetaTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir,
MetaTensor* stats,
MetaTensor* values) {
stats->set_dtype(DataType::INT64);
stats->set_dims(phi::make_ddim({3}));
values->set_dtype(DataType::FLOAT32);
values->set_dims(phi::make_ddim({3}));
}
} // namespace phi
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
......@@ -74,6 +74,15 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
const std::string& data_format,
MetaTensor* out);
void CheckNumericsInferMeta(const MetaTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir,
MetaTensor* stats,
MetaTensor* values);
void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);
void ClassCenterSampleInferMeta(const MetaTensor& label,
......
......@@ -23,7 +23,10 @@ void CheckNumericsKernel(const Context& ctx,
const DenseTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir);
const std::string& output_dir,
DenseTensor* stats,
DenseTensor* values);
} // namespace phi
......@@ -15,12 +15,9 @@ limitations under the License. */
#include "paddle/phi/kernels/check_numerics_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/check_numerics_utils.h"
PHI_DECLARE_int32(check_nan_inf_level);
namespace phi {
template <typename T, typename Context>
......@@ -28,16 +25,29 @@ void CheckNumericsKernel(const Context& ctx,
const DenseTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir) {
const std::string& output_dir,
DenseTensor* stats,
DenseTensor* values) {
// stats stores the checking result of num_nan, num_inf and num_zero.
stats->Resize({static_cast<int64_t>(3)});
int64_t* stats_ptr = ctx.template Alloc<int64_t>(stats);
// values stores the max_value, min_value and mean_value.
values->Resize({static_cast<int64_t>(3)});
float* values_ptr = ctx.template Alloc<float>(values);
std::string cpu_hint_str =
phi::funcs::GetCpuHintString<T>(op_type, var_name, tensor.place());
phi::funcs::CheckNumericsCpuImpl(tensor.data<T>(),
tensor.numel(),
cpu_hint_str,
FLAGS_check_nan_inf_level,
check_nan_inf_level,
"cpu",
output_dir);
output_dir,
stats_ptr,
values_ptr);
}
} // namespace phi
......
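For intuition, what the CPU kernel now writes into `stats` and `values` can be paraphrased in numpy. This is a rough sketch for finite inputs only; the real kernel additionally promotes fp16/bf16 through `MPTypeTrait`, reduces partial results across OpenMP threads, and drives the logging/abort paths:

```python
import numpy as np

def check_numerics_ref(x: np.ndarray):
    # stats: counts of NaN, Inf and zero elements (int64, shape [3]).
    stats = np.array(
        [np.isnan(x).sum(), np.isinf(x).sum(), (x == 0).sum()],
        dtype=np.int64,
    )
    # values: max, min and mean of the tensor (float32, shape [3]).
    values = np.array([x.max(), x.min(), x.mean()], dtype=np.float32)
    return stats, values

stats, values = check_numerics_ref(np.array([1.0, 0.0, -3.0, 2.0]))
print(stats)   # [0 0 1]
print(values)  # [ 2. -3.  0.]
```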
......@@ -61,6 +61,27 @@ HOSTDEVICE bool NeedPrint(MT max_value UNUSED,
return false;
}
template <typename T>
HOSTDEVICE static void SaveStatsAndValues(int64_t num_nan,
int64_t num_inf,
int64_t num_zero,
T max_value,
T min_value,
T mean_value,
int64_t* stats_ptr,
float* values_ptr) {
if (stats_ptr) {
stats_ptr[0] = num_nan;
stats_ptr[1] = num_inf;
stats_ptr[2] = num_zero;
}
if (values_ptr) {
values_ptr[0] = static_cast<float>(max_value);
values_ptr[1] = static_cast<float>(min_value);
values_ptr[2] = static_cast<float>(mean_value);
}
}
HOSTDEVICE static void PrintAndThrowError(const char* debug_info,
int64_t num_nan,
int64_t num_inf,
......@@ -197,8 +218,10 @@ static void CheckNumericsCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const int check_nan_inf_level,
const std::string log_name = "cpu",
const std::string output_dir = "") {
const std::string log_name,
const std::string output_dir,
int64_t* stats_ptr,
float* values_ptr) {
using MT = typename phi::dtype::template MPTypeTrait<T>::Type;
#ifdef _OPENMP
......@@ -263,6 +286,15 @@ static void CheckNumericsCpuImpl(const T* value_ptr,
mean_value += thread_mean_value[i];
}
SaveStatsAndValues<MT>(num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
stats_ptr,
values_ptr);
// Write log to file
if (output_dir.size() > 0) {
WriteToFileForDifferentLevel<T, MT>(cpu_hint_str.c_str(),
......@@ -298,8 +330,10 @@ void CheckNumericsCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const int check_nan_inf_level,
const std::string log_name = "cpu",
const std::string output_dir = "") {
const std::string log_name,
const std::string output_dir,
int64_t* stats_ptr,
float* values_ptr) {
using RealType = typename T::value_type;
RealType real_sum = 0.0f, imag_sum = 0.0f;
......
......@@ -19,13 +19,10 @@ limitations under the License. */
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/check_numerics_utils.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
PHI_DECLARE_int32(check_nan_inf_level);
namespace phi {
static std::once_flag init_multi_gpu_op_var_map_flag;
......@@ -294,7 +291,8 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
int64_t numel,
int64_t numel_max_min,
int check_nan_inf_level,
int64_t* nan_inf_zero_ptr) {
int64_t* stats_ptr,
float* values_ptr) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
int64_t num_nan = 0;
int64_t num_inf = 0;
......@@ -325,11 +323,14 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
min_value = tmp_min_value < min_value ? tmp_min_value : min_value;
mean_value += tmp_mean_value;
}
if (check_nan_inf_level == 0) {
nan_inf_zero_ptr[0] = num_nan;
nan_inf_zero_ptr[1] = num_inf;
nan_inf_zero_ptr[2] = num_zero;
}
phi::funcs::SaveStatsAndValues<MT>(num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
stats_ptr,
values_ptr);
}
phi::funcs::PrintForDifferentLevel<T, MT>(debug_info,
......@@ -364,6 +365,8 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
const std::string& op_type,
const std::string& var_name,
int dev_id) {
std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);
std::string op_var =
GetHintString<T>(op_type, var_name, ctx.GetPlace(), dev_id);
char* gpu_str_ptr = nullptr;
......@@ -417,39 +420,86 @@ static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
return gpu_str_ptr;
}
template <typename T>
static void PrintStack(const phi::GPUContext& ctx,
const DenseTensor& stats,
const std::string& op_type,
const std::string& var_name,
int dev_id) {
auto cpu_stats =
phi::memory_utils::Alloc(phi::CPUPlace(), sizeof(int64_t) * 3);
int64_t* cpu_stats_ptr = reinterpret_cast<int64_t*>(cpu_stats->ptr());
phi::memory_utils::Copy(phi::CPUPlace(),
cpu_stats_ptr,
stats.place(),
stats.data(),
3 * sizeof(int64_t),
ctx.stream());
ctx.Wait();
if (cpu_stats_ptr[0] > 0 || cpu_stats_ptr[1] > 0) {
const std::string debug_info =
GetHintString<T>(op_type, var_name, stats.place(), dev_id);
phi::funcs::PrintAndThrowError(debug_info.c_str(),
cpu_stats_ptr[0],
cpu_stats_ptr[1],
cpu_stats_ptr[2]);
}
}
template <typename T, typename MT>
static void WriteToOutputDir(const phi::GPUContext& ctx,
const DenseTensor& tensor,
const DenseTensor& stats,
const DenseTensor& values,
const std::string& op_type,
const std::string& var_name,
const std::string& output_dir,
const int check_nan_inf_level) {
// Copy stats and values from GPU to CPU.
phi::DenseTensor cpu_stats;
cpu_stats.Resize({static_cast<int64_t>(3)});
phi::Copy(ctx, stats, phi::CPUPlace(), false, &cpu_stats);
phi::DenseTensor cpu_values;
cpu_values.Resize({static_cast<int64_t>(3)});
phi::Copy(ctx, values, phi::CPUPlace(), false, &cpu_values);
ctx.Wait();
int dev_id = tensor.place().device;
const std::string debug_info =
GetHintString<T>(op_type, var_name, tensor.place(), dev_id);
std::string log_name = "gpu." + std::to_string(dev_id);
int64_t* cpu_stats_ptr = cpu_stats.data<int64_t>();
float* cpu_values_ptr = cpu_values.data<float>();
phi::funcs::WriteToFileForDifferentLevel<T, MT>(debug_info.c_str(),
tensor.numel(),
cpu_stats_ptr[0],
cpu_stats_ptr[1],
cpu_stats_ptr[2],
cpu_values_ptr[0],
cpu_values_ptr[1],
cpu_values_ptr[2],
check_nan_inf_level,
log_name,
output_dir);
}
template <typename T, typename Context>
void CheckNumericsKernel(const Context& ctx,
const DenseTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir) {
std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);
const std::string& output_dir,
DenseTensor* stats,
DenseTensor* values) {
int dev_id = tensor.place().device;
VLOG(6) << "op_type=" << op_type << ", var_name=" << var_name
<< ", dev_id=gpu:" << dev_id
<< ", stack_height_limit=" << stack_height_limit
<< ", output_dir=" << output_dir;
// Write log to output_dir.
if (output_dir.size() > 0) {
phi::DenseTensor cpu_tensor;
cpu_tensor.Resize(tensor.dims());
// Copy tensor from GPU to CPU.
phi::Copy(ctx, tensor, CPUPlace(), true, &cpu_tensor);
const std::string debug_info =
GetHintString<T>(op_type, var_name, tensor.place(), dev_id);
std::string log_name = "gpu." + std::to_string(dev_id);
phi::funcs::CheckNumericsCpuImpl(cpu_tensor.data<T>(),
tensor.numel(),
debug_info,
FLAGS_check_nan_inf_level,
log_name,
output_dir);
return;
}
// Print to the standard output.
char* gpu_str_ptr = GetGpuHintStringPtr<T>(ctx, op_type, var_name, dev_id);
......@@ -502,11 +552,13 @@ void CheckNumericsKernel(const Context& ctx,
tensor_block_min_ptr,
tensor_block_mean_ptr);
int check_nan_inf_level = FLAGS_check_nan_inf_level;
// stats stores the checking result of num_nan, num_inf and num_zero.
stats->Resize({static_cast<int64_t>(3)});
int64_t* stats_ptr = ctx.template Alloc<int64_t>(stats);
phi::DenseTensor nan_inf_zero_tensor;
nan_inf_zero_tensor.Resize({static_cast<int64_t>(3)});
int64_t* nan_inf_zero_ptr = ctx.template Alloc<int64_t>(&nan_inf_zero_tensor);
// values stores the max_value, min_value and mean_value.
values->Resize({static_cast<int64_t>(3)});
float* values_ptr = ctx.template Alloc<float>(values);
FindGlobalMaxMinAndPrint<T, MT>
<<<1, 1, 0, ctx.stream()>>>(block_num_nan_ptr,
......@@ -519,25 +571,23 @@ void CheckNumericsKernel(const Context& ctx,
tensor.numel(),
numel_max_min,
check_nan_inf_level,
nan_inf_zero_ptr);
stats_ptr,
values_ptr);
if (output_dir.size() > 0) {
// Write log to output_dir.
WriteToOutputDir<T, MT>(ctx,
tensor,
*stats,
*values,
op_type,
var_name,
output_dir,
check_nan_inf_level);
}
if (check_nan_inf_level == 0 && stack_height_limit > 0) {
auto nan_cpu =
phi::memory_utils::Alloc(phi::CPUPlace(), sizeof(int64_t) * 3);
int64_t* nan_cpu_ptr = reinterpret_cast<int64_t*>(nan_cpu->ptr());
phi::memory_utils::Copy(phi::CPUPlace(),
nan_cpu_ptr,
tensor.place(),
nan_inf_zero_ptr,
3 * sizeof(int64_t),
ctx.stream());
ctx.Wait();
if (nan_cpu_ptr[0] > 0 || nan_cpu_ptr[1] > 0) {
const std::string debug_info =
GetHintString<T>(op_type, var_name, tensor.place(), dev_id);
phi::funcs::PrintAndThrowError(
debug_info.c_str(), nan_cpu_ptr[0], nan_cpu_ptr[1], nan_cpu_ptr[2]);
}
PrintStack<T>(ctx, *stats, op_type, var_name, dev_id);
}
#endif
}
......
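The new `PrintStack` helper replaces the previously inlined copy-and-check: it brings the three int64 counters back to the host and, when `check_nan_inf_level == 0` and `stack_height_limit > 0`, throws if any NaN or Inf was counted. A Python paraphrase of that host-side decision (illustrative only; the real error is raised by `PrintAndThrowError` in C++):

```python
import numpy as np

def maybe_abort(stats: np.ndarray, debug_info: str) -> None:
    # Mirrors PrintStack: stats holds [num_nan, num_inf, num_zero].
    num_nan, num_inf, num_zero = (int(v) for v in stats)
    if num_nan > 0 or num_inf > 0:
        raise RuntimeError(
            f"{debug_info} num_nan={num_nan} num_inf={num_inf} num_zero={num_zero}"
        )

maybe_abort(np.array([0, 0, 2], dtype=np.int64), "[op=pow, var=res]")  # passes
```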
......@@ -19,12 +19,16 @@ from enum import Enum
import numpy as np
import paddle
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.framework import dygraph_only
from ..framework import LayerHelper, in_dynamic_mode
__all__ = [
"DebugMode",
"TensorCheckerConfig",
"check_numerics",
"enable_operator_stats_collection",
"disable_operator_stats_collection",
"collect_operator_stats",
......@@ -259,6 +263,72 @@ class TensorCheckerConfig:
self._set_env(False)
def check_numerics(
tensor, op_type, var_name, debug_mode=DebugMode.CHECK_NAN_INF_AND_ABORT
):
"""
This function is used to debug a tensor, counting the number of NaNs, Infs and zeros in it.
Args:
tensor(Tensor): The target tensor to check.
op_type(str): The OP or API name that produced the target tensor.
var_name(str): The name of the target tensor.
debug_mode(paddle.amp.debugging.DebugMode, optional): The mode of debugging to be used. Default is DebugMode.CHECK_NAN_INF_AND_ABORT.
Returns:
stats(Tensor): The output stats tensor, storing the number of NaNs, Infs and zeros of the input tensor. Its shape is [3] and dtype is int64.
values(Tensor): The output values tensor, storing the maximum, minimum and mean value of the input tensor. Its shape is [3] and dtype is float32.
Examples:
.. code-block:: python
import paddle
checker_config = paddle.amp.debugging.TensorCheckerConfig(
enable=True, debug_mode=paddle.amp.debugging.DebugMode.CHECK_NAN_INF)
x = paddle.to_tensor([1, 0, 3], place=paddle.CPUPlace(), dtype='float32')
y = paddle.to_tensor([0.2, 0, 0.5], place=paddle.CPUPlace(), dtype='float32')
res = paddle.pow(x, y)
stats, values = paddle.amp.debugging.check_numerics(res, "pow", "res")
"""
stack_height_limit = -1
output_dir = ""
if in_dynamic_mode():
return _C_ops.check_numerics(
tensor,
op_type,
var_name,
debug_mode.value,
stack_height_limit,
output_dir,
)
helper = LayerHelper("check_numerics", **locals())
stats = helper.create_variable_for_type_inference(dtype="int64")
values = helper.create_variable_for_type_inference(dtype="float")
helper.append_op(
type='check_numerics',
inputs={
'tensor': tensor,
},
attrs={
'op_type': op_type,
'var_name': var_name,
'check_nan_inf_level': debug_mode.value,
'stack_height_limit': stack_height_limit,
'output_dir': output_dir,
},
outputs={'stats': [stats], 'values': [values]},
)
return stats, values
def _get_operator_stats_flag():
flags = paddle.get_flags(["FLAGS_low_precision_op_list"])
return flags["FLAGS_low_precision_op_list"]
......
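Putting the Python pieces together, eager-mode usage looks like the following, modeled on `TestCheckNumericsAPI.test_eager` in the test hunk below; `DebugMode.CHECK_ALL` is used so the call reports statistics instead of aborting:

```python
import paddle

paddle.device.set_device("cpu")
x = paddle.to_tensor([0.0, 1.0, -2.0, 4.0], dtype="float32")
stats, values = paddle.amp.debugging.check_numerics(
    tensor=x,
    op_type="to_tensor",
    var_name="x",
    debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
)
print(stats.numpy())   # [0 0 1]       -> num_nan, num_inf, num_zero
print(values.numpy())  # [ 4. -2. 0.75] -> max, min, mean
```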
......@@ -30,6 +30,7 @@ class TestNanInfBase(unittest.TestCase):
self._python_interp += " -m coverage run --branch -p"
self.env = os.environ.copy()
paddle.disable_static()
def run_command(self, cmd):
print(f"Run command: {cmd}")
......@@ -44,6 +45,15 @@ class TestNanInfBase(unittest.TestCase):
returncode = proc.returncode
return returncode, out, err
def generate_inputs(self, shape, dtype="float32"):
data = np.random.random(size=shape).astype(dtype)
# [-10, 10)
x = (data * 20 - 10) * np.random.randint(
low=0, high=2, size=shape
).astype(dtype)
y = np.random.randint(low=0, high=2, size=shape).astype(dtype)
return x, y
class TestNanInf(TestNanInfBase):
def setUp(self):
......@@ -172,15 +182,6 @@ class TestNanInfStack(TestNanInfBase):
class TestNanInfCheckResult(TestNanInfBase):
def generate_inputs(self, shape, dtype="float32"):
data = np.random.random(size=shape).astype(dtype)
# [-10, 10)
x = (data * 20 - 10) * np.random.randint(
low=0, high=2, size=shape
).astype(dtype)
y = np.random.randint(low=0, high=2, size=shape).astype(dtype)
return x, y
def get_reference_num_nan_inf(self, x):
out = np.log(x)
num_nan = np.sum(np.isnan(out))
......@@ -271,17 +272,55 @@ class TestNanInfCheckResult(TestNanInfBase):
use_cuda=True, dtype="float16", level=level
)
def test_check_numerics(self):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3}
)
class TestCheckNumericsAPI(TestNanInfBase):
def test_eager(self):
shape = [8, 8]
x_np, y_np = self.generate_inputs(shape, "float16")
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
paddle.fluid.core.check_numerics("check_tensor", x)
paddle.fluid.core.check_numerics("check_tensor", y)
x_np, y_np = self.generate_inputs(shape, "float32")
device_list = ["cpu"]
if paddle.fluid.core.is_compiled_with_cuda():
device_list.append("gpu:0")
for device in device_list:
paddle.device.set_device(device)
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
paddle.amp.debugging.check_numerics(
tensor=x,
op_type="to_tensor",
var_name="x",
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
)
paddle.amp.debugging.check_numerics(
tensor=y,
op_type="to_tensor",
var_name="y",
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
)
def test_static(self):
paddle.enable_static()
shape = [8, 8]
x_np, y_np = self.generate_inputs(shape, "float32")
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=[8, 8], dtype="float32")
y = paddle.static.data(name='y', shape=[8, 8], dtype="float32")
out = paddle.add(x, y)
paddle.amp.debugging.check_numerics(
tensor=out,
op_type="elementwise_add",
var_name=out.name,
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(
main_program, feed={"x": x_np, "y": y_np}, fetch_list=[out.name]
)
paddle.disable_static()
if __name__ == '__main__':
......