提交 86845536 编写于 作者: S sneaxiy

stream callback support in cuda 10

test=develop
上级 8fc05e03
......@@ -24,8 +24,6 @@
namespace paddle {
namespace platform {
using StreamCallback = std::function<void(cudaStream_t, cudaError_t)>;
class StreamCallbackManager;
struct StreamCallbackContext {
......@@ -35,7 +33,7 @@ struct StreamCallbackContext {
: manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own
StreamCallback callback_;
std::function<void()> callback_;
};
class StreamCallbackManager {
......@@ -45,16 +43,18 @@ class StreamCallbackManager {
template <typename Callback>
inline void AddCallback(Callback &&callback) const {
AddCallbackWithStreamAndErrorInfo(
[=](cudaStream_t, cudaError_t) { callback(); });
}
template <typename Callback>
inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const {
auto *stream_callback_context = new StreamCallbackContext(this, callback);
PADDLE_ENFORCE(cudaStreamAddCallback(
stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0));
auto *stream_callback_context =
new StreamCallbackContext(this, std::forward<Callback>(callback));
PADDLE_ENFORCE(
#if CUDA_VERSION >= 10000
cudaLaunchHostFunc(stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context)
#else
cudaStreamAddCallback(stream_,
StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0)
#endif
); // NOLINT
}
void Wait() const { thread_pool_.reset(new ThreadPool(1)); }
......@@ -63,17 +63,21 @@ class StreamCallbackManager {
const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
#if CUDA_VERSION >= 10000
static void CUDART_CB StreamCallbackFunc(void *user_data)
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status,
void *user_data) {
cudaError_t status, void *user_data)
#endif
{
auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue([=]() {
std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr);
callback_context->callback_(stream, status);
callback_context->callback_();
});
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册