stream_callback_manager.cc 3.9 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/stream_callback_manager.h"
16
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
17
#include "paddle/fluid/platform/device/npu/npu_info.h"
18
#include "paddle/fluid/platform/enforce.h"
F
fwenguang 已提交
19 20 21 22
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif
S
sneaxiy 已提交
23 24 25 26

namespace paddle {
namespace platform {

27 28 29
// Takes ownership of the heap-allocated std::function<void()> that
// `user_data` points to, invokes it once, and frees it. The unique_ptr
// guarantees the function object is destroyed even if the call throws.
static void RunAndReleaseStreamCallback(void *user_data) {
  std::unique_ptr<std::function<void()>> func(
      static_cast<std::function<void()> *>(user_data));
  (*func)();
}

// Backend-specific trampolines. Each one only adapts the signature that the
// backend's "enqueue host callback" API expects (see AddCallback) and
// forwards `user_data` to RunAndReleaseStreamCallback.
#ifdef PADDLE_WITH_HIP
static void StreamCallbackFunc(gpuStream_t /*stream*/, gpuError_t /*status*/,
                               void *user_data) {
  RunAndReleaseStreamCallback(user_data);
}
#endif

#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
// Signature used with cudaLaunchHostFunc.
static void CUDART_CB StreamCallbackFunc(void *user_data) {
  RunAndReleaseStreamCallback(user_data);
}
#else
// Signature used with cudaStreamAddCallback.
static void CUDART_CB StreamCallbackFunc(cudaStream_t /*stream*/,
                                         cudaError_t /*status*/,
                                         void *user_data) {
  RunAndReleaseStreamCallback(user_data);
}
#endif
#endif

// Tested with defined() so the check matches the #ifdef guards used for the
// same macros elsewhere in this file.
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
static void StreamCallbackFunc(void *user_data) {
  RunAndReleaseStreamCallback(user_data);
}
#endif
S
sneaxiy 已提交
51

52 53
// Binds the manager to `stream` and constructs its thread pool with an
// argument of 1.
template <typename Stream>
StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
    : stream_(stream),
      thread_pool_(1) {
}
S
sneaxiy 已提交
55

56 57 58
// Schedules `callback` to run after the work currently queued on `stream_`.
//
// Two heap-allocated function objects are involved:
//   * `callback_func` holds the user's callback; it is executed and freed by
//     a task on `thread_pool_`.
//   * `func` is the object handed to the backend as user data; when the
//     backend invokes StreamCallbackFunc it runs `func`, which enqueues the
//     task above and records its future in `last_future_` (under `mtx_`),
//     and is then freed by StreamCallbackFunc.
//
// Both objects are held by unique_ptr until a backend has accepted the
// callback, so they are not leaked when registration throws or when the
// build has no backend that registers anything (the MLU branch).
template <typename Stream>
void StreamCallbackManager<Stream>::AddCallback(
    std::function<void()> callback) const {
  std::unique_ptr<std::function<void()>> callback_owner(
      new std::function<void()>(std::move(callback)));
  auto *callback_func = callback_owner.get();

  std::unique_ptr<std::function<void()>> func_owner(
      new std::function<void()>([this, callback_func] {
        std::lock_guard<std::mutex> lock(mtx_);
        last_future_ = thread_pool_.enqueue([callback_func] {
          // Frees the user's callback once it has run.
          std::unique_ptr<std::function<void()>> releaser(callback_func);
          (*callback_func)();
        });
      }));
  auto *func = func_owner.get();

#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      hipStreamAddCallback(stream_, StreamCallbackFunc, func, 0));
#endif
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaLaunchHostFunc(stream_, StreamCallbackFunc, func));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaStreamAddCallback(stream_, StreamCallbackFunc, func, 0));
#endif
#endif

#ifdef PADDLE_WITH_ASCEND_CL
  VLOG(3) << "aclrtLaunchCallback at stream: " << stream_;
  // TODO(zhiqiu): failed to call aclrtLaunchCallback
  NPULaunchCallback(StreamCallbackFunc, func, ACL_CALLBACK_BLOCK, stream_);
#endif

#ifdef PADDLE_WITH_MLU
  VLOG(3) << "MLULaunchCallback at stream: " << stream_;
  LOG(ERROR) << "failed to call MLULaunchCallback, "
             << "because mlu not support StreamAddCallback yet. "
             << "function: " << func;
  // Nothing was registered: the owners above free both function objects
  // when this function returns.
#endif

#if defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_CUDA) || \
    defined(PADDLE_WITH_ASCEND_CL)
  // Registration succeeded, so ownership has passed to the callback chain:
  // StreamCallbackFunc frees `func`, the thread-pool task frees
  // `callback_func`. release() only clears the owners and never touches the
  // pointees, so this is safe even if the callback has already fired.
  static_cast<void>(func_owner.release());
  static_cast<void>(callback_owner.release());
#endif
}

96 97
// Blocks the caller in two steps: first the backend-specific stream
// synchronization call for `stream_`, then a wait on `last_future_` (the
// future most recently stored by AddCallback's stream-side function), if it
// is valid.
template <typename Stream>
void StreamCallbackManager<Stream>::Wait() const {
#if defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_CUDA)
  platform::GpuStreamSync(stream_);
#endif  // PADDLE_WITH_HIP || PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_MLU)
  PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream_));
#endif  // PADDLE_WITH_MLU
#if defined(PADDLE_WITH_ASCEND_CL)
  NPUStreamSync(stream_);
#endif  // PADDLE_WITH_ASCEND_CL

  // `mtx_` is held for the whole wait, exactly as it guards the assignment
  // to `last_future_` in AddCallback.
  std::lock_guard<std::mutex> guard(mtx_);
  if (!last_future_.valid()) {
    return;
  }
  last_future_.wait();
}

115 116 117 118 119 120 121 122 123
// Explicit instantiations of StreamCallbackManager, one per enabled backend
// stream type.
#if defined(PADDLE_WITH_CUDA)
template struct StreamCallbackManager<gpuStream_t>;
#endif  // PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_HIP)
template struct StreamCallbackManager<hipStream_t>;
#endif  // PADDLE_WITH_HIP
#if defined(PADDLE_WITH_ASCEND_CL)
template struct StreamCallbackManager<aclrtStream>;
#endif  // PADDLE_WITH_ASCEND_CL
F
fwenguang 已提交
124 125 126
// Explicit instantiation of StreamCallbackManager for the MLU stream type.
#if defined(PADDLE_WITH_MLU)
template struct StreamCallbackManager<mluStream>;
#endif  // PADDLE_WITH_MLU
127

S
sneaxiy 已提交
128 129
}  // namespace platform
}  // namespace paddle