Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e7652a37
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e7652a37
编写于
4月 13, 2023
作者:
N
niuliling123
提交者:
GitHub
4月 13, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support print stack when place=CUDAPlace (#52841)
上级
b9ccf0e6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
51 addition
and
19 deletion
+51
-19
paddle/fluid/framework/details/nan_inf_utils_detail.cu
paddle/fluid/framework/details/nan_inf_utils_detail.cu
+39
-3
paddle/fluid/framework/details/nan_inf_utils_detail.h
paddle/fluid/framework/details/nan_inf_utils_detail.h
+12
-16
未找到文件。
paddle/fluid/framework/details/nan_inf_utils_detail.cu
浏览文件 @
e7652a37
...
...
@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
DECLARE_int32
(
check_nan_inf_level
);
...
...
@@ -294,7 +295,8 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
const
char
*
debug_info
,
int64_t
numel
,
int64_t
numel_max_min
,
int
check_nan_inf_level
)
{
int
check_nan_inf_level
,
int64_t
*
nan_inf_zero
)
{
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
int64_t
num_nan
=
0
;
int64_t
num_inf
=
0
;
...
...
@@ -325,8 +327,12 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
min_value
=
tmp_min_value
<
min_value
?
tmp_min_value
:
min_value
;
mean_value
+=
tmp_mean_value
;
}
if
(
check_nan_inf_level
==
0
)
{
nan_inf_zero
[
0
]
=
num_nan
;
nan_inf_zero
[
1
]
=
num_inf
;
nan_inf_zero
[
2
]
=
num_zero
;
}
}
PrintForDifferentLevel
<
T
,
MT
>
(
debug_info
,
numel
,
num_nan
,
...
...
@@ -493,6 +499,10 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
tensor_block_mean_ptr
);
int
check_nan_inf_level
=
FLAGS_check_nan_inf_level
;
phi
::
DenseTensor
nan_inf_zero_tensor
;
nan_inf_zero_tensor
.
Resize
({
static_cast
<
int64_t
>
(
3
)});
int64_t
*
nan_inf_zero
=
dev_ctx
->
template
Alloc
<
int64_t
>(
&
nan_inf_zero_tensor
);
FindGlobalMaxMinAndPrint
<
T
,
MT
>
<<<
1
,
1
,
0
,
dev_ctx
->
stream
()
>>>
(
block_num_nan_ptr
,
block_num_inf_ptr
,
...
...
@@ -503,7 +513,33 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
gpu_str_ptr
,
tensor
.
numel
(),
numel_max_min
,
check_nan_inf_level
);
check_nan_inf_level
,
nan_inf_zero_tensor
.
data
<
int64_t
>
());
if
(
check_nan_inf_level
==
0
)
{
auto
nan_cpu
=
phi
::
memory_utils
::
Alloc
(
phi
::
CPUPlace
(),
sizeof
(
int64_t
)
*
3
);
int64_t
*
nan_cpu_ptr
=
reinterpret_cast
<
int64_t
*>
(
nan_cpu
->
ptr
());
phi
::
memory_utils
::
Copy
(
phi
::
CPUPlace
(),
nan_cpu_ptr
,
place
,
nan_inf_zero
,
3
*
sizeof
(
int64_t
),
dev_ctx
->
stream
());
dev_ctx
->
Wait
();
if
(
nan_cpu_ptr
[
0
]
>
0
||
nan_cpu_ptr
[
1
]
>
0
)
{
const
std
::
string
debug_info
=
GetHintString
<
T
>
(
op_type
,
var_name
,
place
,
dev_id
);
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in "
"%s."
,
static_cast
<
long
long
>
(
nan_cpu_ptr
[
0
]),
// NOLINT
static_cast
<
long
long
>
(
nan_cpu_ptr
[
1
]),
// NOLINT
static_cast
<
long
long
>
(
nan_cpu_ptr
[
2
]),
// NOLINT
debug_info
));
}
}
#endif
}
...
...
paddle/fluid/framework/details/nan_inf_utils_detail.h
浏览文件 @
e7652a37
...
...
@@ -87,15 +87,7 @@ HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
static_cast
<
float
>
(
min_value
),
static_cast
<
float
>
(
mean_value
));
if
(
check_nan_inf_level
==
0
)
{
#if defined(__NVCC__) || defined(__HIPCC__)
PADDLE_ENFORCE
(
false
,
"There are NAN or INF (num_nan=%ld, num_inf=%lld, "
"num_zero=%lld) in %s."
,
static_cast
<
long
long
>
(
num_nan
),
// NOLINT
static_cast
<
long
long
>
(
num_inf
),
// NOLINT
static_cast
<
long
long
>
(
num_zero
),
// NOLINT
debug_info
);
#else
#if !(defined(__NVCC__) || defined(__HIPCC__))
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in "
"%s."
,
...
...
@@ -106,12 +98,15 @@ HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
#endif
}
}
else
if
(
NeedPrint
<
T
,
MT
>
(
max_value
,
min_value
,
check_nan_inf_level
))
{
printf
(
"[PRECISION] in %s, numel=%lld, max=%e, min=%e, mean=%e
\n
"
,
debug_info
,
static_cast
<
long
long
>
(
numel
),
// NOLINT
static_cast
<
float
>
(
max_value
),
static_cast
<
float
>
(
min_value
),
static_cast
<
float
>
(
mean_value
));
printf
(
"[PRECISION] in %s, numel=%lld, num_zero=%lld, max=%e, min=%e, "
"mean=%e
\n
"
,
debug_info
,
static_cast
<
long
long
>
(
numel
),
// NOLINT
static_cast
<
long
long
>
(
num_zero
),
// NOLINT
static_cast
<
float
>
(
max_value
),
static_cast
<
float
>
(
min_value
),
static_cast
<
float
>
(
mean_value
));
}
}
...
...
@@ -152,7 +147,8 @@ void PrintForDifferentLevelFile(const char* debug_info,
<<
", mean="
<<
static_cast
<
float
>
(
mean_value
)
<<
std
::
endl
;
}
else
if
(
NeedPrint
<
T
,
MT
>
(
max_value
,
min_value
,
check_nan_inf_level
))
{
outfile
<<
"[PRECISION] in "
<<
debug_info
<<
", numel="
<<
static_cast
<
long
long
>
(
numel
)
// NOLINT
<<
", numel="
<<
static_cast
<
long
long
>
(
numel
)
// NOLINT
<<
", num_zero="
<<
static_cast
<
long
long
>
(
num_zero
)
// NOLINT
<<
", max="
<<
static_cast
<
float
>
(
max_value
)
<<
", min="
<<
static_cast
<
float
>
(
min_value
)
<<
", mean="
<<
static_cast
<
float
>
(
mean_value
)
<<
std
::
endl
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录