未验证 提交 19eefef4 编写于 作者: X XiangGao 提交者: GitHub

Check for cuda errors immediately after kernel launch (#32557)

Co-authored-by: NYang Zhang <yangzhang@live.com>
上级 c1db7e32
......@@ -134,6 +134,17 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
};
template <typename PlaceType>
inline void CheckKernelLaunch(const char* op_type){};
#ifdef PADDLE_WITH_CUDA
template <>
inline void CheckKernelLaunch<::paddle::platform::CUDAPlace>(
const char* op_type) {
PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS(op_type);
};
#endif
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;
......@@ -162,8 +173,9 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
RegisterKernelClass<PlaceType, T>(
op_type, library_type, customized_type_value,
[](const framework::ExecutionContext& ctx) {
[op_type](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
CheckKernelLaunch<PlaceType>(op_type);
});
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
......@@ -223,8 +235,13 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {
RegisterKernelClass<PlaceType, T>(op_type, library_type,
customized_type_value, Functor());
RegisterKernelClass<PlaceType, T>(
op_type, library_type, customized_type_value,
[op_type](const framework::ExecutionContext& ctx) {
Functor()(ctx);
CheckKernelLaunch<PlaceType>(op_type);
});
constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
......
......@@ -991,6 +991,16 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
} \
} while (0)
#define PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS(OP) \
do { \
auto res = cudaGetLastError(); \
if (UNLIKELY(res != cudaSuccess)) { \
auto msg = ::paddle::platform::build_nvidia_error_msg(res); \
PADDLE_THROW(platform::errors::Fatal("CUDA error after kernel (%s): %s", \
OP, msg)); \
} \
} while (0)
inline void retry_sleep(unsigned milliseconds) {
#ifdef _WIN32
Sleep(milliseconds);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册