Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c18fddd3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2297
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
c18fddd3
编写于
1月 31, 2023
作者:
N
niuliling123
提交者:
GitHub
1月 31, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Save nan log to file when output_dir is setted (#49200)
上级
0e51f398
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
344 addition
and
112 deletion
+344
-112
paddle/fluid/framework/details/nan_inf_utils_detail.cc
paddle/fluid/framework/details/nan_inf_utils_detail.cc
+17
-106
paddle/fluid/framework/details/nan_inf_utils_detail.cu
paddle/fluid/framework/details/nan_inf_utils_detail.cu
+30
-4
paddle/fluid/framework/details/nan_inf_utils_detail.h
paddle/fluid/framework/details/nan_inf_utils_detail.h
+185
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-0
python/paddle/fluid/tests/unittests/test_nan_inf_dir.py
python/paddle/fluid/tests/unittests/test_nan_inf_dir.py
+108
-0
未找到文件。
paddle/fluid/framework/details/nan_inf_utils_detail.cc
浏览文件 @
c18fddd3
...
@@ -30,6 +30,23 @@ DECLARE_int32(check_nan_inf_level);
...
@@ -30,6 +30,23 @@ DECLARE_int32(check_nan_inf_level);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
struct
DebugTools
{
DebugTools
()
{}
std
::
string
path
=
""
;
};
static
DebugTools
debug_nan_inf
;
void
SetNanInfDebugPath
(
const
std
::
string
&
nan_inf_path
)
{
debug_nan_inf
.
path
=
nan_inf_path
;
VLOG
(
4
)
<<
"Set the log's path of debug tools : "
<<
nan_inf_path
;
}
std
::
string
GetNanPath
()
{
if
(
debug_nan_inf
.
path
.
empty
())
{
return
""
;
}
return
debug_nan_inf
.
path
+
"/"
;
}
static
std
::
once_flag
white_list_init_flag
;
static
std
::
once_flag
white_list_init_flag
;
...
@@ -134,112 +151,6 @@ static void InitWhiteListFormEnv() {
...
@@ -134,112 +151,6 @@ static void InitWhiteListFormEnv() {
}
}
}
}
template
<
typename
T
,
std
::
enable_if_t
<!
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
float
>
>::
value
&&
!
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
double
>>::
value
,
bool
>
=
true
>
static
void
CheckNanInfCpuImpl
(
const
T
*
value_ptr
,
const
int64_t
numel
,
const
std
::
string
&
cpu_hint_str
)
{
using
MT
=
typename
phi
::
dtype
::
template
MPTypeTrait
<
T
>
::
Type
;
#ifdef _OPENMP
// Use maximum 4 threads to collect the nan and inf information.
int
num_threads
=
std
::
max
(
omp_get_num_threads
(),
1
);
num_threads
=
std
::
min
(
num_threads
,
4
);
#else
int
num_threads
=
1
;
#endif
std
::
vector
<
int64_t
>
thread_num_nan
(
num_threads
,
0
);
std
::
vector
<
int64_t
>
thread_num_inf
(
num_threads
,
0
);
std
::
vector
<
MT
>
thread_min_value
(
num_threads
,
static_cast
<
MT
>
(
value_ptr
[
0
]));
std
::
vector
<
MT
>
thread_max_value
(
num_threads
,
static_cast
<
MT
>
(
value_ptr
[
0
]));
std
::
vector
<
MT
>
thread_mean_value
(
num_threads
,
static_cast
<
MT
>
(
0
));
#ifdef _OPENMP
#pragma omp parallel num_threads(num_threads)
#endif
{
#ifdef _OPENMP
int64_t
tid
=
omp_get_thread_num
();
int64_t
chunk_size
=
(
numel
+
num_threads
-
1
)
/
num_threads
;
int64_t
begin
=
tid
*
chunk_size
;
int64_t
end
=
chunk_size
+
begin
>
numel
?
numel
:
chunk_size
+
begin
;
#else
int64_t
tid
=
0
;
int64_t
begin
=
0
;
int64_t
end
=
numel
;
#endif
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
MT
value
=
static_cast
<
MT
>
(
value_ptr
[
i
]);
thread_min_value
[
tid
]
=
std
::
min
(
thread_min_value
[
tid
],
value
);
thread_max_value
[
tid
]
=
std
::
max
(
thread_max_value
[
tid
],
value
);
thread_mean_value
[
tid
]
+=
value
/
static_cast
<
MT
>
(
numel
);
if
(
std
::
isnan
(
value
))
{
thread_num_nan
[
tid
]
+=
1
;
}
else
if
(
std
::
isinf
(
value
))
{
thread_num_inf
[
tid
]
+=
1
;
}
}
}
int64_t
num_nan
=
0
;
int64_t
num_inf
=
0
;
MT
min_value
=
thread_min_value
[
0
];
MT
max_value
=
thread_max_value
[
0
];
MT
mean_value
=
static_cast
<
MT
>
(
0
);
for
(
int
i
=
0
;
i
<
num_threads
;
++
i
)
{
num_nan
+=
thread_num_nan
[
i
];
num_inf
+=
thread_num_inf
[
i
];
min_value
=
std
::
min
(
thread_min_value
[
i
],
min_value
);
max_value
=
std
::
max
(
thread_max_value
[
i
],
max_value
);
mean_value
+=
thread_mean_value
[
i
];
}
PrintForDifferentLevel
<
T
,
MT
>
(
cpu_hint_str
.
c_str
(),
numel
,
num_nan
,
num_inf
,
max_value
,
min_value
,
mean_value
,
FLAGS_check_nan_inf_level
);
}
template
<
typename
T
,
std
::
enable_if_t
<
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
float
>
>::
value
||
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
double
>>::
value
,
bool
>
=
true
>
void
CheckNanInfCpuImpl
(
const
T
*
value_ptr
,
const
int64_t
numel
,
const
std
::
string
&
cpu_hint_str
)
{
using
RealType
=
typename
T
::
value_type
;
RealType
real_sum
=
0.0
f
,
imag_sum
=
0.0
f
;
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum)
#endif
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
T
value
=
value_ptr
[
i
];
real_sum
+=
(
value
.
real
-
value
.
real
);
imag_sum
+=
(
value
.
imag
-
value
.
imag
);
}
if
(
std
::
isnan
(
real_sum
)
||
std
::
isinf
(
real_sum
)
||
std
::
isnan
(
imag_sum
)
||
std
::
isinf
(
imag_sum
))
{
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"There are NAN or INF in %s."
,
cpu_hint_str
));
}
}
template
<
>
template
<
>
template
<
typename
T
>
template
<
typename
T
>
void
TensorCheckerVisitor
<
phi
::
CPUContext
>::
apply
(
void
TensorCheckerVisitor
<
phi
::
CPUContext
>::
apply
(
...
...
paddle/fluid/framework/details/nan_inf_utils_detail.cu
浏览文件 @
c18fddd3
...
@@ -322,18 +322,26 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
...
@@ -322,18 +322,26 @@ __global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
}
}
template
<
typename
T
>
template
<
typename
T
>
static
char
*
GetGpuHintStringPtr
(
const
phi
::
GPUContext
&
ctx
,
inline
std
::
string
GetHintString
(
const
std
::
string
&
op_type
,
const
std
::
string
&
op_type
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
int
dev_id
)
{
const
phi
::
Place
&
place
,
int
dev_id
=
-
1
)
{
std
::
string
op_var
=
GetCpuHintString
<
T
>
(
op_type
,
var_name
,
place
,
dev_id
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
(
dev_id
>=
0
&&
dev_id
<
multi_op_var2gpu_str_mutex
().
size
()),
(
dev_id
>=
0
&&
dev_id
<
multi_op_var2gpu_str_mutex
().
size
()),
true
,
true
,
platform
::
errors
::
OutOfRange
(
"GPU dev_id must >=0 and < dev_count=%d"
,
platform
::
errors
::
OutOfRange
(
"GPU dev_id must >=0 and < dev_count=%d"
,
multi_op_var2gpu_str_mutex
().
size
()));
multi_op_var2gpu_str_mutex
().
size
()));
return
op_var
;
}
template
<
typename
T
>
static
char
*
GetGpuHintStringPtr
(
const
phi
::
GPUContext
&
ctx
,
const
std
::
string
&
op_type
,
const
std
::
string
&
var_name
,
int
dev_id
)
{
std
::
string
op_var
=
std
::
string
op_var
=
Get
Cpu
HintString
<
T
>
(
op_type
,
var_name
,
ctx
.
GetPlace
(),
dev_id
);
GetHintString
<
T
>
(
op_type
,
var_name
,
ctx
.
GetPlace
(),
dev_id
);
char
*
gpu_str_ptr
=
nullptr
;
char
*
gpu_str_ptr
=
nullptr
;
{
{
...
@@ -396,6 +404,24 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
...
@@ -396,6 +404,24 @@ void TensorCheckerVisitor<phi::GPUContext>::apply(
auto
*
dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
auto
*
dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
tensor
.
place
()));
platform
::
DeviceContextPool
::
Instance
().
Get
(
tensor
.
place
()));
int
dev_id
=
tensor
.
place
().
device
;
int
dev_id
=
tensor
.
place
().
device
;
// Write log to file
auto
file_path
=
GetNanPath
();
if
(
file_path
.
size
()
>
0
)
{
phi
::
DenseTensor
cpu_tensor
;
platform
::
CPUPlace
cpu_place
;
cpu_tensor
.
Resize
(
tensor
.
dims
());
// 1. copy from gpu to cpu
paddle
::
framework
::
TensorCopySync
(
tensor
,
cpu_place
,
&
cpu_tensor
);
auto
*
dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
tensor
.
place
()));
const
std
::
string
debug_info
=
GetHintString
<
T
>
(
op_type
,
var_name
,
place
,
dev_id
);
// 2. write log to file
CheckNanInfCpuImpl
(
cpu_tensor
.
data
<
T
>
(),
tensor
.
numel
(),
debug_info
,
"gpu"
);
return
;
}
// Write log to window
char
*
gpu_str_ptr
=
char
*
gpu_str_ptr
=
GetGpuHintStringPtr
<
T
>
(
*
dev_ctx
,
op_type
,
var_name
,
dev_id
);
GetGpuHintStringPtr
<
T
>
(
*
dev_ctx
,
op_type
,
var_name
,
dev_id
);
...
...
paddle/fluid/framework/details/nan_inf_utils_detail.h
浏览文件 @
c18fddd3
...
@@ -13,17 +13,33 @@
...
@@ -13,17 +13,33 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include <fstream>
#include <iostream>
#include <string>
#include <string>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#define MKDIR(path) _mkdir(path)
#else
#include <sys/stat.h>
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif
DECLARE_int32
(
check_nan_inf_level
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
SetNanInfDebugPath
(
const
std
::
string
&
nan_inf_path
);
std
::
string
GetNanPath
();
template
<
typename
T
,
template
<
typename
T
,
typename
MT
,
typename
MT
,
std
::
enable_if_t
<
std
::
is_same
<
T
,
float
>
::
value
,
bool
>
=
true
>
std
::
enable_if_t
<
std
::
is_same
<
T
,
float
>
::
value
,
bool
>
=
true
>
...
@@ -93,6 +109,49 @@ HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
...
@@ -93,6 +109,49 @@ HOSTDEVICE void PrintForDifferentLevel(const char* debug_info,
}
}
}
}
template
<
typename
T
,
typename
MT
>
void
PrintForDifferentLevelFile
(
const
char
*
debug_info
,
int64_t
numel
,
int64_t
num_nan
,
int64_t
num_inf
,
MT
max_value
,
MT
min_value
,
MT
mean_value
,
int
check_nan_inf_level
,
const
std
::
string
&
log_name
)
{
int
dev_id
=
0
;
#ifdef PADDLE_WITH_HIP
hipGetDevice
(
&
dev_id
);
#elif PADDLE_WITH_CUDA
cudaGetDevice
(
&
dev_id
);
#endif
auto
file_path
=
GetNanPath
();
MKDIR
(
file_path
.
c_str
());
std
::
string
file_name
=
"worker_"
+
log_name
+
"."
+
std
::
to_string
(
dev_id
);
std
::
string
path
=
file_path
+
file_name
;
std
::
ofstream
outfile
(
path
,
std
::
ios
::
app
);
if
(
!
outfile
.
is_open
())
{
return
;
}
if
(
num_nan
>
0
||
num_inf
>
0
)
{
outfile
<<
"[PRECISION] [ERROR] in "
<<
debug_info
<<
", numel="
<<
static_cast
<
long
long
>
(
numel
)
// NOLINT
<<
", num_nan="
<<
static_cast
<
long
long
>
(
num_nan
)
// NOLINT
<<
", num_inf="
<<
static_cast
<
long
long
>
(
num_inf
)
// NOLINT
<<
", max="
<<
static_cast
<
float
>
(
max_value
)
<<
", min="
<<
static_cast
<
float
>
(
min_value
)
<<
", mean="
<<
static_cast
<
float
>
(
mean_value
)
<<
std
::
endl
;
}
else
if
(
NeedPrint
<
T
,
MT
>
(
max_value
,
min_value
,
check_nan_inf_level
))
{
outfile
<<
"[PRECISION] in "
<<
debug_info
<<
", numel="
<<
static_cast
<
long
long
>
(
numel
)
// NOLINT
<<
", max="
<<
static_cast
<
float
>
(
max_value
)
<<
", min="
<<
static_cast
<
float
>
(
min_value
)
<<
", mean="
<<
static_cast
<
float
>
(
mean_value
)
<<
std
::
endl
;
}
outfile
.
close
();
}
template
<
typename
T
>
template
<
typename
T
>
inline
std
::
string
GetCpuHintString
(
const
std
::
string
&
op_type
,
inline
std
::
string
GetCpuHintString
(
const
std
::
string
&
op_type
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
...
@@ -120,6 +179,130 @@ inline std::string GetCpuHintString(const std::string& op_type,
...
@@ -120,6 +179,130 @@ inline std::string GetCpuHintString(const std::string& op_type,
return
ss
.
str
();
return
ss
.
str
();
}
}
template
<
typename
T
,
std
::
enable_if_t
<!
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
float
>
>::
value
&&
!
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
double
>>::
value
,
bool
>
=
true
>
static
void
CheckNanInfCpuImpl
(
const
T
*
value_ptr
,
const
int64_t
numel
,
const
std
::
string
&
cpu_hint_str
,
const
std
::
string
log_name
=
"cpu"
)
{
using
MT
=
typename
phi
::
dtype
::
template
MPTypeTrait
<
T
>
::
Type
;
#ifdef _OPENMP
// Use maximum 4 threads to collect the nan and inf information.
int
num_threads
=
std
::
max
(
omp_get_num_threads
(),
1
);
num_threads
=
std
::
min
(
num_threads
,
4
);
#else
int
num_threads
=
1
;
#endif
std
::
vector
<
int64_t
>
thread_num_nan
(
num_threads
,
0
);
std
::
vector
<
int64_t
>
thread_num_inf
(
num_threads
,
0
);
std
::
vector
<
MT
>
thread_min_value
(
num_threads
,
static_cast
<
MT
>
(
value_ptr
[
0
]));
std
::
vector
<
MT
>
thread_max_value
(
num_threads
,
static_cast
<
MT
>
(
value_ptr
[
0
]));
std
::
vector
<
MT
>
thread_mean_value
(
num_threads
,
static_cast
<
MT
>
(
0
));
#ifdef _OPENMP
#pragma omp parallel num_threads(num_threads)
#endif
{
#ifdef _OPENMP
int64_t
tid
=
omp_get_thread_num
();
int64_t
chunk_size
=
(
numel
+
num_threads
-
1
)
/
num_threads
;
int64_t
begin
=
tid
*
chunk_size
;
int64_t
end
=
chunk_size
+
begin
>
numel
?
numel
:
chunk_size
+
begin
;
#else
int64_t
tid
=
0
;
int64_t
begin
=
0
;
int64_t
end
=
numel
;
#endif
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
MT
value
=
static_cast
<
MT
>
(
value_ptr
[
i
]);
thread_min_value
[
tid
]
=
std
::
min
(
thread_min_value
[
tid
],
value
);
thread_max_value
[
tid
]
=
std
::
max
(
thread_max_value
[
tid
],
value
);
thread_mean_value
[
tid
]
+=
value
/
static_cast
<
MT
>
(
numel
);
if
(
std
::
isnan
(
value
))
{
thread_num_nan
[
tid
]
+=
1
;
}
else
if
(
std
::
isinf
(
value
))
{
thread_num_inf
[
tid
]
+=
1
;
}
}
}
int64_t
num_nan
=
0
;
int64_t
num_inf
=
0
;
MT
min_value
=
thread_min_value
[
0
];
MT
max_value
=
thread_max_value
[
0
];
MT
mean_value
=
static_cast
<
MT
>
(
0
);
for
(
int
i
=
0
;
i
<
num_threads
;
++
i
)
{
num_nan
+=
thread_num_nan
[
i
];
num_inf
+=
thread_num_inf
[
i
];
min_value
=
std
::
min
(
thread_min_value
[
i
],
min_value
);
max_value
=
std
::
max
(
thread_max_value
[
i
],
max_value
);
mean_value
+=
thread_mean_value
[
i
];
}
auto
file_path
=
GetNanPath
();
// Write log to file
if
(
file_path
.
size
()
>
0
)
{
VLOG
(
4
)
<<
"[FLAGS_check_nan_inf_level="
<<
FLAGS_check_nan_inf_level
<<
"]. Write log to "
<<
file_path
;
PrintForDifferentLevelFile
<
T
,
MT
>
(
cpu_hint_str
.
c_str
(),
numel
,
num_nan
,
num_inf
,
max_value
,
min_value
,
mean_value
,
FLAGS_check_nan_inf_level
,
log_name
);
return
;
}
PrintForDifferentLevel
<
T
,
MT
>
(
cpu_hint_str
.
c_str
(),
numel
,
num_nan
,
num_inf
,
max_value
,
min_value
,
mean_value
,
FLAGS_check_nan_inf_level
);
}
template
<
typename
T
,
std
::
enable_if_t
<
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
float
>
>::
value
||
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
double
>>::
value
,
bool
>
=
true
>
void
CheckNanInfCpuImpl
(
const
T
*
value_ptr
,
const
int64_t
numel
,
const
std
::
string
&
cpu_hint_str
,
const
std
::
string
log_name
=
"cpu"
)
{
using
RealType
=
typename
T
::
value_type
;
RealType
real_sum
=
0.0
f
,
imag_sum
=
0.0
f
;
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum)
#endif
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
T
value
=
value_ptr
[
i
];
real_sum
+=
(
value
.
real
-
value
.
real
);
imag_sum
+=
(
value
.
imag
-
value
.
imag
);
}
if
(
std
::
isnan
(
real_sum
)
||
std
::
isinf
(
real_sum
)
||
std
::
isnan
(
imag_sum
)
||
std
::
isinf
(
imag_sum
))
{
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"There are NAN or INF in %s."
,
cpu_hint_str
));
}
}
template
<
typename
DeviceContext
>
template
<
typename
DeviceContext
>
struct
TensorCheckerVisitor
{
struct
TensorCheckerVisitor
{
TensorCheckerVisitor
(
const
std
::
string
&
o
,
TensorCheckerVisitor
(
const
std
::
string
&
o
,
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
c18fddd3
...
@@ -34,6 +34,7 @@ limitations under the License. */
...
@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
...
@@ -2671,6 +2672,9 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -2671,6 +2672,9 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"use_layout_autotune"
,
m
.
def
(
"use_layout_autotune"
,
[]
{
return
egr
::
Controller
::
Instance
().
UseLayoutAutoTune
();
});
[]
{
return
egr
::
Controller
::
Instance
().
UseLayoutAutoTune
();
});
// Add the api for nan op debug
m
.
def
(
"set_nan_inf_debug_path"
,
&
paddle
::
framework
::
details
::
SetNanInfDebugPath
);
BindFleetWrapper
(
&
m
);
BindFleetWrapper
(
&
m
);
BindIO
(
&
m
);
BindIO
(
&
m
);
...
...
python/paddle/fluid/tests/unittests/test_nan_inf_dir.py
0 → 100644
浏览文件 @
c18fddd3
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
unittest
import
numpy
as
np
import
paddle
class
TestNanInfDirCheckResult
(
unittest
.
TestCase
):
def
generate_inputs
(
self
,
shape
,
dtype
=
"float32"
):
data
=
np
.
random
.
random
(
size
=
shape
).
astype
(
dtype
)
# [-10, 10)
x
=
(
data
*
20
-
10
)
*
np
.
random
.
randint
(
low
=
0
,
high
=
2
,
size
=
shape
).
astype
(
dtype
)
y
=
np
.
random
.
randint
(
low
=
0
,
high
=
2
,
size
=
shape
).
astype
(
dtype
)
return
x
,
y
def
get_reference_num_nan_inf
(
self
,
x
):
out
=
np
.
log
(
x
)
num_nan
=
np
.
sum
(
np
.
isnan
(
out
))
num_inf
=
np
.
sum
(
np
.
isinf
(
out
))
print
(
"[reference] num_nan={}, num_inf={}"
.
format
(
num_nan
,
num_inf
))
return
num_nan
,
num_inf
def
get_num_nan_inf
(
self
,
x_np
,
use_cuda
=
True
,
add_assert
=
False
,
pt
=
"nan_inf_log_dir"
):
num_nan
=
0
num_inf
=
0
if
add_assert
:
if
use_cuda
:
paddle
.
device
.
set_device
(
"gpu:0"
)
else
:
paddle
.
device
.
set_device
(
"cpu"
)
x
=
paddle
.
to_tensor
(
x_np
)
out
=
paddle
.
log
(
x
)
sys
.
stdout
.
flush
()
if
not
use_cuda
:
os
.
path
.
exists
(
pt
)
num_nan
=
0
num_inf
=
0
for
root
,
dirs
,
files
in
os
.
walk
(
pt
):
for
file_name
in
files
:
if
file_name
.
startswith
(
'worker_cpu'
):
file_path
=
os
.
path
.
join
(
root
,
file_name
)
with
open
(
file_path
,
"rb"
)
as
fp
:
for
e
in
fp
:
err_str_list
=
(
str
(
e
)
.
replace
(
"("
,
" "
)
.
replace
(
")"
,
" "
)
.
replace
(
","
,
" "
)
.
split
(
" "
)
)
for
err_str
in
err_str_list
:
if
"num_nan"
in
err_str
:
num_nan
=
int
(
err_str
.
split
(
"="
)[
1
])
elif
"num_inf"
in
err_str
:
num_inf
=
int
(
err_str
.
split
(
"="
)[
1
])
print
(
"[paddle] num_nan={}, num_inf={}"
.
format
(
num_nan
,
num_inf
)
)
return
num_nan
,
num_inf
def
test_num_nan_inf
(
self
):
path
=
"nan_inf_log_dir"
paddle
.
fluid
.
core
.
set_nan_inf_debug_path
(
path
)
def
_check_num_nan_inf
(
use_cuda
):
shape
=
[
32
,
32
]
x_np
,
_
=
self
.
generate_inputs
(
shape
)
num_nan_np
,
num_inf_np
=
self
.
get_reference_num_nan_inf
(
x_np
)
add_assert
=
(
num_nan_np
+
num_inf_np
)
>
0
num_nan
,
num_inf
=
self
.
get_num_nan_inf
(
x_np
,
use_cuda
,
add_assert
,
path
)
if
not
use_cuda
:
assert
num_nan
==
num_nan_np
and
num_inf
==
num_inf_np
paddle
.
set_flags
(
{
"FLAGS_check_nan_inf"
:
1
,
"FLAGS_check_nan_inf_level"
:
3
}
)
_check_num_nan_inf
(
use_cuda
=
False
)
if
paddle
.
fluid
.
core
.
is_compiled_with_cuda
():
_check_num_nan_inf
(
use_cuda
=
True
)
x
=
paddle
.
to_tensor
([
2
,
3
,
4
],
'float32'
)
y
=
paddle
.
to_tensor
([
1
,
5
,
2
],
'float32'
)
z
=
paddle
.
add
(
x
,
y
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录