Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wmsofts
Paddle
提交
6c471ed0
P
Paddle
项目概览
wmsofts
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
6c471ed0
编写于
2月 28, 2023
作者:
N
niuliling123
提交者:
GitHub
2月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Count the number of 0 in the output Tensor (#50981)
上级
49752074
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
56 additions
and
19 deletions
+56
-19
paddle/fluid/framework/details/nan_inf_utils_detail.cu
paddle/fluid/framework/details/nan_inf_utils_detail.cu
+27
-6
paddle/fluid/framework/details/nan_inf_utils_detail.h
paddle/fluid/framework/details/nan_inf_utils_detail.h
+29
-13
未找到文件。
paddle/fluid/framework/details/nan_inf_utils_detail.cu
浏览文件 @
6c471ed0
...
...
@@ -174,15 +174,19 @@ __device__ T BlockReduce(T value) {
__device__
void
BlockReduceNumNanInfAndWrite
(
const
int64_t
num_nan
,
const
int64_t
num_inf
,
const
int64_t
num_zero
,
int64_t
offset
,
int64_t
*
num_nan_ptr
,
int64_t
*
num_inf_ptr
)
{
int64_t
*
num_inf_ptr
,
int64_t
*
num_zero_ptr
)
{
int64_t
block_num_nan
=
BlockReduce
<
int64_t
,
2
>
(
num_nan
);
int64_t
block_num_inf
=
BlockReduce
<
int64_t
,
2
>
(
num_inf
);
int64_t
block_num_zero
=
BlockReduce
<
int64_t
,
2
>
(
num_zero
);
if
(
threadIdx
.
x
==
0
)
{
num_nan_ptr
[
offset
]
=
block_num_nan
;
num_inf_ptr
[
offset
]
=
block_num_inf
;
num_zero_ptr
[
offset
]
=
block_num_zero
;
}
}
...
...
@@ -233,6 +237,7 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
const
int64_t
numel
,
int64_t
*
block_num_nan_ptr
,
int64_t
*
block_num_inf_ptr
,
int64_t
*
block_num_zero_ptr
,
MT
*
tensor_block_max_ptr
,
MT
*
tensor_block_min_ptr
,
MT
*
tensor_block_mean_ptr
)
{
...
...
@@ -240,6 +245,7 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
int64_t
num_nan
=
0
;
int64_t
num_inf
=
0
;
int64_t
num_zero
=
0
;
MT
max_value
=
static_cast
<
MT
>
(
i
<
numel
?
value_ptr
[
i
]
:
value_ptr
[
0
]);
MT
min_value
=
static_cast
<
MT
>
(
i
<
numel
?
value_ptr
[
i
]
:
value_ptr
[
0
]);
...
...
@@ -256,10 +262,18 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
}
else
if
(
isinf
(
value
))
{
num_inf
+=
1
;
}
if
(
value
==
static_cast
<
MT
>
(
0
))
{
num_zero
+=
1
;
}
}
BlockReduceNumNanInfAndWrite
(
num_nan
,
num_inf
,
blockIdx
.
x
,
block_num_nan_ptr
,
block_num_inf_ptr
);
BlockReduceNumNanInfAndWrite
(
num_nan
,
num_inf
,
num_zero
,
blockIdx
.
x
,
block_num_nan_ptr
,
block_num_inf_ptr
,
block_num_zero_ptr
);
BlockReduceMaxMinAndWrite
<
MT
>
(
max_value
,
min_value
,
...
...
@@ -273,6 +287,7 @@ __global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
template
<
typename
T
,
typename
MT
>
__global__
void
FindGlobalMaxMinAndPrint
(
const
int64_t
*
block_num_nan_ptr
,
const
int64_t
*
block_num_inf_ptr
,
const
int64_t
*
block_num_zero_ptr
,
const
MT
*
tensor_block_max_ptr
,
const
MT
*
tensor_block_min_ptr
,
const
MT
*
tensor_block_mean_ptr
,
...
...
@@ -283,11 +298,13 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
int64_t
num_nan
=
0
;
int64_t
num_inf
=
0
;
int64_t
num_zero
=
0
;
// numel_max_min <= 128
for
(
int64_t
i
=
0
;
i
<
numel_max_min
;
++
i
)
{
num_nan
+=
block_num_nan_ptr
[
i
];
num_inf
+=
block_num_inf_ptr
[
i
];
num_zero
+=
block_num_zero_ptr
[
i
];
}
MT
max_value
=
static_cast
<
MT
>
(
0
);
...
...
@@ -314,6 +331,7 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
numel
,
num_nan
,
num_inf
,
num_zero
,
max_value
,
min_value
,
mean_value
,
...
...
@@ -451,11 +469,12 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
int64_t
numel_max_min
=
blocks
;
phi
::
DenseTensor
block_num_nan_inf
;
block_num_nan_inf
.
Resize
({
static_cast
<
int64_t
>
(
2
*
numel_max_min
)});
phi
::
DenseTensor
block_num_nan_inf
_zero
;
block_num_nan_inf
_zero
.
Resize
({
static_cast
<
int64_t
>
(
3
*
numel_max_min
)});
int64_t
*
block_num_nan_ptr
=
dev_ctx
->
template
Alloc
<
int64_t
>(
&
block_num_nan_inf
);
dev_ctx
->
template
Alloc
<
int64_t
>(
&
block_num_nan_inf
_zero
);
int64_t
*
block_num_inf_ptr
=
block_num_nan_ptr
+
numel_max_min
;
int64_t
*
block_num_zero_ptr
=
block_num_inf_ptr
+
numel_max_min
;
phi
::
DenseTensor
tensor_block_max_min
;
tensor_block_max_min
.
Resize
({
static_cast
<
int64_t
>
(
3
*
numel_max_min
)});
...
...
@@ -468,6 +487,7 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
tensor
.
numel
(),
block_num_nan_ptr
,
block_num_inf_ptr
,
block_num_zero_ptr
,
tensor_block_max_ptr
,
tensor_block_min_ptr
,
tensor_block_mean_ptr
);
...
...
@@ -476,6 +496,7 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
FindGlobalMaxMinAndPrint
<
T
,
MT
>
<<<
1
,
1
,
0
,
dev_ctx
->
stream
()
>>>
(
block_num_nan_ptr
,
block_num_inf_ptr
,
block_num_zero_ptr
,
tensor_block_max_ptr
,
tensor_block_min_ptr
,
tensor_block_mean_ptr
,
...
...
paddle/fluid/framework/details/nan_inf_utils_detail.h
浏览文件 @
6c471ed0
...
...
@@ -69,6 +69,7 @@ HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
int64_t
numel
,
int64_t
num_nan
,
int64_t
num_inf
,
int64_t
num_zero
,
MT
max_value
,
MT
min_value
,
MT
mean_value
,
...
...
@@ -76,26 +77,31 @@ HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
if
(
num_nan
>
0
||
num_inf
>
0
)
{
printf
(
"[PRECISION] [ERROR] in %s, numel=%lld, num_nan=%lld, "
"num_inf=%lld, max=%e, min=%e, mean=%e
\n
"
,
"num_inf=%lld,
num_zero=%lld,
max=%e, min=%e, mean=%e
\n
"
,
debug_info
,
static_cast
<
long
long
>
(
numel
),
// NOLINT
static_cast
<
long
long
>
(
num_nan
),
// NOLINT
static_cast
<
long
long
>
(
num_inf
),
// NOLINT
static_cast
<
long
long
>
(
numel
),
// NOLINT
static_cast
<
long
long
>
(
num_nan
),
// NOLINT
static_cast
<
long
long
>
(
num_inf
),
// NOLINT
static_cast
<
long
long
>
(
num_zero
),
// NOLINT
static_cast
<
float
>
(
max_value
),
static_cast
<
float
>
(
min_value
),
static_cast
<
float
>
(
mean_value
));
if
(
check_nan_inf_level
==
0
)
{
#if defined(__NVCC__) || defined(__HIPCC__)
PADDLE_ENFORCE
(
false
,
"There are NAN or INF (num_nan=%ld, num_inf=%lld) in %s."
,
static_cast
<
long
long
>
(
num_nan
),
// NOLINT
static_cast
<
long
long
>
(
num_inf
),
// NOLINT
"There are NAN or INF (num_nan=%ld, num_inf=%lld, "
"num_zero=%lld) in %s."
,
static_cast
<
long
long
>
(
num_nan
),
// NOLINT
static_cast
<
long
long
>
(
num_inf
),
// NOLINT
static_cast
<
long
long
>
(
num_zero
),
// NOLINT
debug_info
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"There are NAN or INF (num_nan=%lld, num_inf=%lld) in %s."
,
static_cast
<
long
long
>
(
num_nan
),
// NOLINT
static_cast
<
long
long
>
(
num_inf
),
// NOLINT
"There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in "
"%s."
,
static_cast
<
long
long
>
(
num_nan
),
// NOLINT
static_cast
<
long
long
>
(
num_inf
),
// NOLINT
static_cast
<
long
long
>
(
num_zero
),
// NOLINT
debug_info
));
#endif
}
...
...
@@ -114,6 +120,7 @@ void PrintForDifferentLevelFile(const char* debug_info,
int64_t
numel
,
int64_t
num_nan
,
int64_t
num_inf
,
int64_t
num_zero
,
MT
max_value
,
MT
min_value
,
MT
mean_value
,
...
...
@@ -136,9 +143,10 @@ void PrintForDifferentLevelFile(const char* debug_info,
if
(
num_nan
>
0
||
num_inf
>
0
)
{
outfile
<<
"[PRECISION] [ERROR] in "
<<
debug_info
<<
", numel="
<<
static_cast
<
long
long
>
(
numel
)
// NOLINT
<<
", num_nan="
<<
static_cast
<
long
long
>
(
num_nan
)
// NOLINT
<<
", num_inf="
<<
static_cast
<
long
long
>
(
num_inf
)
// NOLINT
<<
", numel="
<<
static_cast
<
long
long
>
(
numel
)
// NOLINT
<<
", num_nan="
<<
static_cast
<
long
long
>
(
num_nan
)
// NOLINT
<<
", num_inf="
<<
static_cast
<
long
long
>
(
num_inf
)
// NOLINT
<<
", num_zero="
<<
static_cast
<
long
long
>
(
num_zero
)
// NOLINT
<<
", max="
<<
static_cast
<
float
>
(
max_value
)
<<
", min="
<<
static_cast
<
float
>
(
min_value
)
<<
", mean="
<<
static_cast
<
float
>
(
mean_value
)
<<
std
::
endl
;
...
...
@@ -200,6 +208,7 @@ static void CheckNanInfCpuImpl(const T* value_ptr,
std
::
vector
<
int64_t
>
thread_num_nan
(
num_threads
,
0
);
std
::
vector
<
int64_t
>
thread_num_inf
(
num_threads
,
0
);
std
::
vector
<
int64_t
>
thread_num_zero
(
num_threads
,
0
);
std
::
vector
<
MT
>
thread_min_value
(
num_threads
,
static_cast
<
MT
>
(
value_ptr
[
0
]));
std
::
vector
<
MT
>
thread_max_value
(
num_threads
,
static_cast
<
MT
>
(
value_ptr
[
0
]));
std
::
vector
<
MT
>
thread_mean_value
(
num_threads
,
static_cast
<
MT
>
(
0
));
...
...
@@ -230,17 +239,22 @@ static void CheckNanInfCpuImpl(const T* value_ptr,
}
else
if
(
std
::
isinf
(
value
))
{
thread_num_inf
[
tid
]
+=
1
;
}
if
(
value
==
0
)
{
thread_num_zero
[
tid
]
+=
1
;
}
}
}
int64_t
num_nan
=
0
;
int64_t
num_inf
=
0
;
int64_t
num_zero
=
0
;
MT
min_value
=
thread_min_value
[
0
];
MT
max_value
=
thread_max_value
[
0
];
MT
mean_value
=
static_cast
<
MT
>
(
0
);
for
(
int
i
=
0
;
i
<
num_threads
;
++
i
)
{
num_nan
+=
thread_num_nan
[
i
];
num_inf
+=
thread_num_inf
[
i
];
num_zero
+=
thread_num_zero
[
i
];
min_value
=
std
::
min
(
thread_min_value
[
i
],
min_value
);
max_value
=
std
::
max
(
thread_max_value
[
i
],
max_value
);
mean_value
+=
thread_mean_value
[
i
];
...
...
@@ -254,6 +268,7 @@ static void CheckNanInfCpuImpl(const T* value_ptr,
numel
,
num_nan
,
num_inf
,
num_zero
,
max_value
,
min_value
,
mean_value
,
...
...
@@ -266,6 +281,7 @@ static void CheckNanInfCpuImpl(const T* value_ptr,
numel
,
num_nan
,
num_inf
,
num_zero
,
max_value
,
min_value
,
mean_value
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录