// stream_callback_manager.cc (3.3 KB) -- last committed by sneaxiy
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/stream_callback_manager.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace platform {

21 22 23
#ifdef PADDLE_WITH_HIP
static void StreamCallbackFunc(gpuStream_t stream, gpuError_t status,
                               void *user_data)
24 25 26 27
#endif
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
    static void CUDART_CB StreamCallbackFunc(void *user_data)
S
fix bug  
sneaxiy 已提交
28
#else
29 30 31 32 33 34 35
    static void CUDART_CB
    StreamCallbackFunc(cudaStream_t stream, cudaError_t status, void *user_data)
#endif
#endif

#if PADDLE_WITH_ASCEND_CL
        static void StreamCallbackFunc(void *user_data)
S
fix bug  
sneaxiy 已提交
36 37 38 39 40 41
#endif
{
  std::unique_ptr<std::function<void()>> func(
      reinterpret_cast<std::function<void()> *>(user_data));
  (*func)();
}
S
sneaxiy 已提交
42

43 44
template <typename Stream>
StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
S
fix bug  
sneaxiy 已提交
45
    : stream_(stream), thread_pool_(1) {}
S
sneaxiy 已提交
46

47 48 49
// Schedules `callback` to run on the host once the work already enqueued on
// `stream_` has reached this point.
//
// The device runtime invokes StreamCallbackFunc with the `func` closure built
// below. That closure does not run `callback` directly. It enqueues it on
// `thread_pool_` and records the returned future in `last_future_` under
// `mtx_`, so the user callback runs on the pool rather than on the runtime's
// callback thread.
//
// Ownership:
//  * `callback` is moved into a shared_ptr that is shared by the closure and
//    the pool task. It is destroyed when the last of them is destroyed.
//  * The closure itself is owned by a unique_ptr until the launch API has
//    accepted it. If the launch fails, the PADDLE_ENFORCE_* macro reports the
//    error and the unique_ptr frees the closure, so nothing leaks. On success,
//    ownership passes to StreamCallbackFunc, which deletes it after running.
template <typename Stream>
void StreamCallbackManager<Stream>::AddCallback(
    std::function<void()> callback) const {
  auto callback_func =
      std::make_shared<std::function<void()>>(std::move(callback));
  std::unique_ptr<std::function<void()>> func(
      new std::function<void()>([this, callback_func] {
        std::lock_guard<std::mutex> lock(mtx_);
        last_future_ =
            thread_pool_.enqueue([callback_func] { (*callback_func)(); });
      }));

#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipStreamAddCallback(stream_, StreamCallbackFunc, func.get(), 0));
#endif
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaLaunchHostFunc(stream_, StreamCallbackFunc, func.get()));
#else
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaStreamAddCallback(stream_, StreamCallbackFunc, func.get(), 0));
#endif
#endif

#ifdef PADDLE_WITH_ASCEND_CL
  PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func.get(),
                                                 ACL_CALLBACK_BLOCK, stream_));
#endif

  // The launch was accepted, so StreamCallbackFunc now owns the closure.
  // Drop our ownership without deleting it. release() only clears the stored
  // pointer, so this is safe even if the callback has already run.
  func.release();
}

// Blocks until callbacks previously added through AddCallback have finished.
//
// Step 1 synchronizes `stream_`, which waits for the trampolines launched on
// it to run. Each trampoline enqueues its callback and stores the future in
// `last_future_`. Step 2 then waits on the most recently stored future.
// NOTE(review): waiting on only the last future covers all callbacks only if
// `thread_pool_` runs its tasks in FIFO order on a single worker -- confirm.
template <typename Stream>
void StreamCallbackManager<Stream>::Wait() const {
#if defined(PADDLE_WITH_HIP)
  PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_));
#endif
#if defined(PADDLE_WITH_CUDA)
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
  PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_));
#endif

  // `last_future_` is assigned by the trampoline while holding `mtx_`; read it
  // under the same mutex.
  std::lock_guard<std::mutex> guard(mtx_);
  if (!last_future_.valid()) {
    return;  // Nothing has been enqueued yet.
  }
  last_future_.wait();
}

// Explicit instantiations, one per enabled backend's stream type, so the
// member definitions above are emitted in this translation unit.
#if defined(PADDLE_WITH_CUDA)
template struct StreamCallbackManager<gpuStream_t>;
#endif
#if defined(PADDLE_WITH_HIP)
template struct StreamCallbackManager<hipStream_t>;
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
template struct StreamCallbackManager<aclrtStream>;
#endif

}  // namespace platform
}  // namespace paddle