未验证 提交 02d757da 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #14268 from sneaxiy/stream_callback_support_in_cuda10

Stream Callback Support in CUDA 10
...@@ -24,8 +24,6 @@ ...@@ -24,8 +24,6 @@
namespace paddle { namespace paddle {
namespace platform { namespace platform {
using StreamCallback = std::function<void(cudaStream_t, cudaError_t)>;
class StreamCallbackManager; class StreamCallbackManager;
struct StreamCallbackContext { struct StreamCallbackContext {
...@@ -35,7 +33,7 @@ struct StreamCallbackContext { ...@@ -35,7 +33,7 @@ struct StreamCallbackContext {
: manager_(manager), callback_(callback) {} : manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own const StreamCallbackManager *manager_; // do not own
StreamCallback callback_; std::function<void()> callback_;
}; };
class StreamCallbackManager { class StreamCallbackManager {
...@@ -45,16 +43,18 @@ class StreamCallbackManager { ...@@ -45,16 +43,18 @@ class StreamCallbackManager {
template <typename Callback> template <typename Callback>
inline void AddCallback(Callback &&callback) const { inline void AddCallback(Callback &&callback) const {
AddCallbackWithStreamAndErrorInfo( auto *stream_callback_context =
[=](cudaStream_t, cudaError_t) { callback(); }); new StreamCallbackContext(this, std::forward<Callback>(callback));
} PADDLE_ENFORCE(
#if CUDA_VERSION >= 10000
template <typename Callback> cudaLaunchHostFunc(stream_, StreamCallbackManager::StreamCallbackFunc,
inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const { stream_callback_context)
auto *stream_callback_context = new StreamCallbackContext(this, callback); #else
PADDLE_ENFORCE(cudaStreamAddCallback( cudaStreamAddCallback(stream_,
stream_, StreamCallbackManager::StreamCallbackFunc, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0)); stream_callback_context, 0)
#endif
); // NOLINT
} }
// Replaces the manager's thread pool with a brand-new single-thread pool.
// The previous pool is destroyed by unique_ptr::reset after the new one has
// been constructed.
// NOTE(review): presumably ThreadPool's destructor drains its pending tasks,
// which is what would make this a "wait" -- confirm against ThreadPool.
void Wait() const {
  auto *replacement_pool = new ThreadPool(1);
  thread_pool_.reset(replacement_pool);
}
...@@ -63,17 +63,21 @@ class StreamCallbackManager { ...@@ -63,17 +63,21 @@ class StreamCallbackManager {
const cudaStream_t stream_; const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_; mutable std::unique_ptr<ThreadPool> thread_pool_;
// Trampoline registered with CUDA by AddCallback.
//
// CUDA API calls are not allowed from inside a stream callback, so the user
// callback is not run here; it is forwarded to the owning manager's thread
// pool instead.
//
// user_data: the StreamCallbackContext allocated by AddCallback. The queued
// task takes ownership of it and frees it once the callback has run.
#if CUDA_VERSION >= 10000
static void CUDART_CB StreamCallbackFunc(void *user_data)
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
                                         cudaError_t status, void *user_data)
#endif
{
  auto *raw_context = static_cast<StreamCallbackContext *>(user_data);
  const StreamCallbackManager *owner = raw_context->manager_;
  owner->thread_pool_->enqueue([raw_context]() {
    // The guard deletes the context when the task finishes.
    std::unique_ptr<StreamCallbackContext> context_guard(raw_context);
    context_guard->callback_();
  });
}
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册