stream_callback_manager.cc 2.9 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/stream_callback_manager.h"
G
GaoWei8 已提交
16
#include <utility>
S
sneaxiy 已提交
17 18 19 20 21
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace platform {

22
#ifdef PADDLE_WITH_CUDA
S
fix bug  
sneaxiy 已提交
23
#if CUDA_VERSION >= 10000
S
fix bug  
sneaxiy 已提交
24
static void CUDART_CB StreamCallbackFunc(void *user_data)
S
fix bug  
sneaxiy 已提交
25 26 27 28
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
                                         cudaError_t status, void *user_data)
#endif
29 30 31 32 33
#endif

#if PADDLE_WITH_ASCEND_CL
static void StreamCallbackFunc(void *user_data)
#endif
S
fix bug  
sneaxiy 已提交
34 35 36 37 38
{
  std::unique_ptr<std::function<void()>> func(
      reinterpret_cast<std::function<void()> *>(user_data));
  (*func)();
}
S
sneaxiy 已提交
39

40 41
// Creates a manager bound to `stream`.
//
// `stream` is stored by value in stream_; thread_pool_ is constructed with
// the argument 1 (presumably a single worker thread -- confirm against the
// thread pool's constructor).
template <typename Stream>
StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
    : stream_(stream),
      thread_pool_(1) {}
S
sneaxiy 已提交
43

44 45
// Registers `callback` on stream_ through the backend's launch-callback API.
//
// The user callback is not run directly by the function launched on the
// stream. That function only enqueues the user callback into thread_pool_ and
// stores the resulting future in last_future_ (under mtx_), which is what
// Wait() later blocks on.
//
// Ownership of the two heap allocations made here:
//   - `func` is handed to the runtime as user data and is destroyed by
//     StreamCallbackFunc after it has been invoked.
//   - `callback_func` is destroyed by the thread-pool task that runs it.
//
// NOTE(review): if the enforce macros below throw on a failed launch, neither
// allocation appears to be reclaimed on that path -- confirm whether that is
// acceptable before changing it, since a failed launch may or may not have
// queued the callback.
template <typename Stream>
void StreamCallbackManager<Stream>::AddCallback(
    std::function<void()> callback) const {
  auto *callback_func = new std::function<void()>(std::move(callback));
  auto *func = new std::function<void()>([this, callback_func] {
    std::lock_guard<std::mutex> lock(mtx_);
    last_future_ = thread_pool_.enqueue([callback_func] {
      // Takes ownership so the callback object is freed after it runs.
      std::unique_ptr<std::function<void()>> releaser(callback_func);
      (*callback_func)();
    });
  });

#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaLaunchHostFunc(stream_, StreamCallbackFunc, func));
#else
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaStreamAddCallback(stream_, StreamCallbackFunc, func, 0));
#endif
#endif

// Tested with #ifdef (not #if) so this guard agrees with the
// PADDLE_WITH_ASCEND_CL guards used by Wait() and by the explicit
// instantiations at the end of this file.
#ifdef PADDLE_WITH_ASCEND_CL
  PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func,
                                                 ACL_CALLBACK_BLOCK, stream_));
#endif
}

70 71 72
// Synchronizes stream_ through the backend API, then waits on last_future_
// (the future stored by the most recent callback hand-off) if it is valid.
template <typename Stream>
void StreamCallbackManager<Stream>::Wait() const {
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
#endif
#ifdef PADDLE_WITH_ASCEND_CL
  PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_));
#endif

  // mtx_ stays locked until the function returns, i.e. for the whole wait on
  // the future.
  std::lock_guard<std::mutex> guard(mtx_);
  if (!last_future_.valid()) {
    return;  // No callback task has been recorded; nothing to wait for.
  }
  last_future_.wait();
}

86 87 88 89 90 91 92
// Explicit instantiations. The member templates above are defined in this
// source file, so the manager is instantiated here for the stream type of
// each backend that is compiled in.
#ifdef PADDLE_WITH_CUDA
template struct StreamCallbackManager<cudaStream_t>;
#endif
#ifdef PADDLE_WITH_ASCEND_CL
template struct StreamCallbackManager<aclrtStream>;
#endif

S
sneaxiy 已提交
93 94
}  // namespace platform
}  // namespace paddle