Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
54bc3b46
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
54bc3b46
编写于
11月 05, 2022
作者:
Y
Yiqun Liu
提交者:
GitHub
11月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use an unified FLAGS_check_nan_inf_level to control the result of checking infinite. (#47672)
上级
99504cbb
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
57 additions
and
44 deletions
+57
-44
paddle/fluid/framework/details/nan_inf_utils_detail.cu
paddle/fluid/framework/details/nan_inf_utils_detail.cu
+42
-22
paddle/fluid/platform/flags.cc
paddle/fluid/platform/flags.cc
+15
-22
未找到文件。
paddle/fluid/framework/details/nan_inf_utils_detail.cu
浏览文件 @
54bc3b46
...
...
@@ -25,8 +25,7 @@
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
DECLARE_bool
(
abort_on_nan_inf
);
DECLARE_bool
(
check_tensor_max_min
);
DECLARE_int32
(
check_nan_inf_level
);
namespace
paddle
{
namespace
framework
{
...
...
@@ -233,23 +232,46 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
tensor_block_mean_ptr
);
}
// Decides whether a float tensor's statistics should be printed, based on
// FLAGS_check_nan_inf_level. This overload is selected only when T is float.
template <typename T,
          typename MT,
          std::enable_if_t<std::is_same<T, float>::value, bool> = true>
__device__ bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
  if (check_nan_inf_level >= 3) {
    // Level >= 3: print the statistics of every tensor.
    return true;
  }
  if (check_nan_inf_level >= 2) {
    // Level 2: print only when the float tensor's range exceeds what
    // float16 can represent (overflow risk under mixed precision).
    const MT fp16_limit =
        static_cast<MT>(std::numeric_limits<phi::dtype::float16>::max());
    return max_value > fp16_limit || min_value < -fp16_limit;
  }
  return false;
}
// Non-float overload: the float16-overflow check (level 2) applies only to
// float tensors, so statistics are printed solely at the most verbose level.
template <typename T,
          typename MT,
          std::enable_if_t<!std::is_same<T, float>::value, bool> = true>
__device__ bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
  return check_nan_inf_level >= 3;
}
template
<
typename
T
,
typename
MT
>
__global__
void
FindGlobalMaxMinAndPrint
(
const
int
*
found_nan_inf_ptr
,
const
T
*
tensor_block_max_ptr
,
const
T
*
tensor_block_min_ptr
,
const
T
*
tensor_block_mean_ptr
,
const
M
T
*
tensor_block_max_ptr
,
const
M
T
*
tensor_block_min_ptr
,
const
M
T
*
tensor_block_mean_ptr
,
const
char
*
debug_info
,
int64_t
numel
,
int64_t
numel_max_min
,
bool
abort_on_nan_inf
,
bool
check_tensor_max_min
)
{
int
check_nan_inf_level
)
{
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
int
has_nan
=
found_nan_inf_ptr
[
0
];
int
has_inf
=
found_nan_inf_ptr
[
1
];
T
max_value
=
static_cast
<
T
>
(
0
);
T
min_value
=
static_cast
<
T
>
(
0
);
T
mean_value
=
static_cast
<
T
>
(
0
);
MT
max_value
=
static_cast
<
M
T
>
(
0
);
MT
min_value
=
static_cast
<
M
T
>
(
0
);
MT
mean_value
=
static_cast
<
M
T
>
(
0
);
if
(
tensor_block_max_ptr
&&
tensor_block_min_ptr
&&
tensor_block_mean_ptr
)
{
max_value
=
tensor_block_max_ptr
[
0
];
min_value
=
tensor_block_min_ptr
[
0
];
...
...
@@ -257,9 +279,9 @@ __global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
// numel_max_min <= 128
for
(
int64_t
i
=
1
;
i
<
numel_max_min
;
++
i
)
{
T
tmp_max_value
=
tensor_block_max_ptr
[
i
];
T
tmp_min_value
=
tensor_block_min_ptr
[
i
];
T
tmp_mean_value
=
tensor_block_mean_ptr
[
i
];
M
T
tmp_max_value
=
tensor_block_max_ptr
[
i
];
M
T
tmp_min_value
=
tensor_block_min_ptr
[
i
];
M
T
tmp_mean_value
=
tensor_block_mean_ptr
[
i
];
max_value
=
tmp_max_value
>
max_value
?
tmp_max_value
:
max_value
;
min_value
=
tmp_min_value
<
min_value
?
tmp_min_value
:
min_value
;
...
...
@@ -268,7 +290,7 @@ __global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
}
if
(
has_nan
||
has_inf
)
{
if
(
abort_on_nan_inf
)
{
if
(
check_nan_inf_level
==
0
)
{
PADDLE_ENFORCE
(
false
,
"===[PRECISION] [ERROR] in %s, numel=%ld, find_nan=%d, "
"find_inf=%d, "
...
...
@@ -280,7 +302,7 @@ __global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
static_cast
<
float
>
(
max_value
),
static_cast
<
float
>
(
min_value
),
static_cast
<
float
>
(
mean_value
));
}
else
{
}
else
if
(
check_nan_inf_level
>=
1
)
{
printf
(
"===[PRECISION] [ERROR] in %s, numel=%ld, find_nan=%d, "
"find_inf=%d, "
...
...
@@ -293,7 +315,7 @@ __global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
static_cast
<
float
>
(
min_value
),
static_cast
<
float
>
(
mean_value
));
}
}
else
if
(
check_tensor_max_min
)
{
}
else
if
(
NeedPrint
<
T
,
MT
>
(
max_value
,
min_value
,
check_nan_inf_level
)
)
{
printf
(
"[PRECISION] in %s, numel=%ld, max=%e, min=%e, mean=%e
\n
"
,
debug_info
,
numel
,
...
...
@@ -423,9 +445,8 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
tensor_block_min_ptr
,
tensor_block_mean_ptr
);
bool
abort_on_nan_inf
=
FLAGS_abort_on_nan_inf
;
bool
check_tensor_max_min
=
FLAGS_check_tensor_max_min
;
FindGlobalMaxMinAndPrint
<
MT
>
int
check_nan_inf_level
=
FLAGS_check_nan_inf_level
;
FindGlobalMaxMinAndPrint
<
T
,
MT
>
<<<
1
,
1
,
0
,
dev_ctx
->
stream
()
>>>
(
found_nan_inf_ptr
,
tensor_block_max_ptr
,
tensor_block_min_ptr
,
...
...
@@ -433,8 +454,7 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
gpu_str_ptr
,
tensor_
.
numel
(),
numel_max_min
,
abort_on_nan_inf
,
check_tensor_max_min
);
check_nan_inf_level
);
#endif
}
...
...
paddle/fluid/platform/flags.cc
浏览文件 @
54bc3b46
...
...
@@ -70,31 +70,24 @@ PADDLE_DEFINE_EXPORTED_bool(
/**
* Operator related FLAG
* Name: FLAGS_
abort_on_nan_inf
* Name: FLAGS_
check_nan_inf_level
* Since Version: 2.5.0
* Value Range: bool, default=true
* Example:
* Note: Used to debug. Whether abort the process when any operator produce
* NAN/INF. It only works when FLAGS_check_nan_inf is set.
*/
PADDLE_DEFINE_EXPORTED_bool
(
abort_on_nan_inf
,
true
,
"Whether abort the process when any operator produce NAN/INF or not."
);
/**
* Operator related FLAG
* Name: FLAGS_check_tensor_max_min
* Since Version: 2.5.0
* Value Range: bool, default=false
* Value Range: int32, default=0
* Example:
* Note: Used to debug. Enable to calculate and print the max and min value of
* each operator's output tensor. It only works when FLAGS_check_nan_inf is set.
* Note: Used to debug. Setting the check and print level when
* FLAGS_check_nan_inf is set.
* - 0, abort the process when any operator produce NAN/INF and only print the
* information of tensor which holds NAN/INF.
* - 1, continue the training or inference process and print the information of
* all tensors which holds NAN/INF.
* - 2, print the information of float tensors when the max or min value
* overflowing float16's limit.
* - 3, print the information of all tensors.
*/
PADDLE_DEFINE_EXPORTED_
bool
(
check_
tensor_max_min
,
false
,
"
Whether to check all the output tensors's min and max value
."
);
PADDLE_DEFINE_EXPORTED_
int32
(
check_
nan_inf_level
,
0
,
"
Setting the check and print level when FLAGS_check_nan_inf is set
."
);
/**
* Operator related FLAG
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录