// stream_callback_manager.cc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/stream_callback_manager.h"

#include <functional>
#include <memory>
#include <mutex>
#include <utility>

#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace platform {

23 24 25
#ifdef PADDLE_WITH_HIP
static void StreamCallbackFunc(gpuStream_t stream, gpuError_t status,
                               void *user_data)
26 27 28 29
#endif
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
    static void CUDART_CB StreamCallbackFunc(void *user_data)
S
fix bug  
sneaxiy 已提交
30
#else
31 32 33 34 35 36 37
    static void CUDART_CB
    StreamCallbackFunc(cudaStream_t stream, cudaError_t status, void *user_data)
#endif
#endif

#if PADDLE_WITH_ASCEND_CL
        static void StreamCallbackFunc(void *user_data)
S
fix bug  
sneaxiy 已提交
38 39 40 41 42 43
#endif
{
  std::unique_ptr<std::function<void()>> func(
      reinterpret_cast<std::function<void()> *>(user_data));
  (*func)();
}
S
sneaxiy 已提交
44

45 46
// Binds the manager to `stream`.
//
// thread_pool_ is constructed with 1, presumably a single worker thread, so
// user callbacks execute one at a time. Wait() only waits on the most recently
// enqueued task and relies on this serialization. It also assumes the pool
// runs tasks in submission order; TODO confirm against ThreadPool.
template <typename Stream>
StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
    : stream_(stream), thread_pool_(1) {}

// Schedules `callback` to run on the host once all work queued on stream_
// before this call has completed.
//
// A small host function (`func`) is launched on stream_. When the stream
// reaches it, it only enqueues the user callback onto thread_pool_ and stores
// the returned future in last_future_ under mtx_ (Wait() reads it). The user
// callback therefore runs on the pool's worker thread, not on the thread that
// executes stream host functions.
//
// Ownership:
//   * The user callback is moved into a heap-allocated std::function that the
//     pool task deletes after running it.
//   * `func` itself is heap-allocated and deleted by StreamCallbackFunc.
//   * Both are held by unique_ptr until the launch call has returned.
//     PADDLE_ENFORCE_GPU_SUCCESS reports a failed launch by throwing
//     (PADDLE_ENFORCE_* semantics; confirm), and in that case both are freed
//     instead of leaked.
//
// NOTE(review): `func` captures `this`, so the manager must outlive every
// callback that is still pending on the stream.
template <typename Stream>
void StreamCallbackManager<Stream>::AddCallback(
    std::function<void()> callback) const {
  std::unique_ptr<std::function<void()>> callback_holder(
      new std::function<void()>(std::move(callback)));
  auto *callback_func = callback_holder.get();
  std::unique_ptr<std::function<void()>> func(
      new std::function<void()>([this, callback_func] {
        std::lock_guard<std::mutex> lock(mtx_);
        last_future_ = thread_pool_.enqueue([callback_func] {
          std::unique_ptr<std::function<void()>> releaser(callback_func);
          (*callback_func)();
        });
      }));

#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      hipStreamAddCallback(stream_, StreamCallbackFunc, func.get(), 0));
#endif
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaLaunchHostFunc(stream_, StreamCallbackFunc, func.get()));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaStreamAddCallback(stream_, StreamCallbackFunc, func.get(), 0));
#endif
#endif

#ifdef PADDLE_WITH_ASCEND_CL
  VLOG(3) << "aclrtLaunchCallback at stream: " << stream_;
  // TODO(zhiqiu): failed to call aclrtLaunchCallback
  NPULaunchCallback(StreamCallbackFunc, func.get(), ACL_CALLBACK_BLOCK,
                    stream_);
#endif

  // The launch went through, so the stream now owns both allocations.
  // StreamCallbackFunc deletes `func`, and the pool task deletes
  // `callback_func`. The host function may already have run and freed them;
  // release() only drops the pointers and never dereferences them.
  func.release();
  callback_holder.release();
}

// Blocks the calling thread until the callbacks added via AddCallback() have
// finished.
//
// Step 1 synchronizes stream_. The host functions launched by AddCallback are
// stream work, so after the sync they have run and stored their pool task's
// future in last_future_. For the NPU path this assumes NPULaunchCallback work
// is ordered on the stream the same way; confirm.
//
// Step 2 waits on that most recent future. thread_pool_ is built with a single
// worker, so finishing the last task implies the earlier ones have finished,
// assuming the pool runs tasks in submission order.
template <typename Stream>
void StreamCallbackManager<Stream>::Wait() const {
#if defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_CUDA)
  platform::GpuStreamSync(stream_);
#endif
#ifdef PADDLE_WITH_ASCEND_CL
  NPUStreamSync(stream_);
#endif
  {
    // last_future_ is assigned by host functions that may run on another
    // thread, so read it under the same mutex AddCallback's host function
    // takes.
    std::lock_guard<std::mutex> lock(mtx_);
    if (last_future_.valid()) {
      last_future_.wait();
    }
  }
}

// Explicit instantiations for the stream type of each enabled backend. The
// member templates are defined in this .cc file, so (assuming the header only
// declares them) other translation units link against these instantiations.
#ifdef PADDLE_WITH_CUDA
template struct StreamCallbackManager<gpuStream_t>;
#endif
#ifdef PADDLE_WITH_HIP
template struct StreamCallbackManager<hipStream_t>;
#endif
#ifdef PADDLE_WITH_ASCEND_CL
template struct StreamCallbackManager<aclrtStream>;
#endif

}  // namespace platform
}  // namespace paddle