diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 818da7478b2392841d0b1b7221270b6f840465ec..9f0dc50774addc1fa7b674d329095dc61458e03a 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -134,6 +134,17 @@ class OpRegistry { static std::unique_ptr CreateOp(const OpDesc& op_desc); }; +template +inline void CheckKernelLaunch(const char* op_type){}; + +#ifdef PADDLE_WITH_CUDA +template <> +inline void CheckKernelLaunch<::paddle::platform::CUDAPlace>( + const char* op_type) { + PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS(op_type); +}; +#endif + template struct OpKernelRegistrarFunctor; @@ -162,8 +173,9 @@ struct OpKernelRegistrarFunctor { RegisterKernelClass( op_type, library_type, customized_type_value, - [](const framework::ExecutionContext& ctx) { + [op_type](const framework::ExecutionContext& ctx) { KERNEL_TYPE().Compute(ctx); + CheckKernelLaunch(op_type); }); constexpr auto size = std::tuple_size>::value; OpKernelRegistrarFunctor @@ -223,8 +235,13 @@ struct OpKernelRegistrarFunctorEx(op_type, library_type, - customized_type_value, Functor()); + RegisterKernelClass( + op_type, library_type, customized_type_value, + + [op_type](const framework::ExecutionContext& ctx) { + Functor()(ctx); + CheckKernelLaunch(op_type); + }); constexpr auto size = std::tuple_size>::value; diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index cfca3ceadf41a2e769569da7f56ac01d56ad2341..d42733823e669b03daa8f29dfa0c40be38de1069 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -991,6 +991,16 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); } \ } while (0) +#define PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS(OP) \ + do { \ + auto res = cudaGetLastError(); \ + if (UNLIKELY(res != cudaSuccess)) { \ + auto msg = ::paddle::platform::build_nvidia_error_msg(res); \ + PADDLE_THROW(platform::errors::Fatal("CUDA error after kernel (%s): %s", \ + OP, msg)); \ + } \ + } while (0) + inline void retry_sleep(unsigned milliseconds) { #ifdef _WIN32 Sleep(milliseconds);