/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>
#include <memory>
#include <utility>

#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream_callback_manager.h"

namespace paddle {
namespace platform {
namespace stream {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

// Priority level requested for a stream; stored in one byte.
// NOTE(review): how these values map onto driver-level stream priorities is
// decided where they are consumed (not visible in this header) — confirm there.
enum class Priority : uint8_t {
  kNull = 0,    // NOTE(review): presumably "no priority specified" — confirm.
  kHigh = 1,
  kNormal = 2,  // Default used by the CUDAStream constructor and Init().
};

// Creation flag requested for a stream; stored in one byte.
// NOTE(review): the numeric values look like they mirror the runtime's
// default / non-blocking stream-creation flags (0x0 / 0x1) — confirm where the
// flag is consumed (not visible in this header).
enum class StreamFlag : uint8_t {
  kDefaultFlag = 0,  // Default used by the CUDAStream constructor and Init().
  kStreamNonBlocking = 1,
};

#endif
class CUDAStream final {
45
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
46

47 48
 public:
  CUDAStream() = default;
49
  explicit CUDAStream(const Place& place,
50 51 52
                      const Priority& priority = Priority::kNormal,
                      const StreamFlag& flag = StreamFlag::kDefaultFlag) {
    Init(place, priority, flag);
53 54 55
  }
  virtual ~CUDAStream() { Destroy(); }

56 57
  bool Init(const Place& place, const Priority& priority = Priority::kNormal,
            const StreamFlag& flag = StreamFlag::kDefaultFlag);
58 59 60 61 62 63 64

  template <typename Callback>
  void AddCallback(Callback&& callback) const {
    callback_manager_->AddCallback(callback);
  }

  template <typename Callback>
65 66 67
#ifdef PADDLE_WITH_HIP
  void RecordEvent(hipEvent_t ev, Callback callback) const {
    callback();
68
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(ev, stream_));
69 70
  }
#else
71 72
  void RecordEvent(cudaEvent_t ev, Callback callback) const {
    callback();
73
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(ev, stream_));
74
  }
75
#endif
76

77 78
#ifdef PADDLE_WITH_HIP
  void RecordEvent(hipEvent_t ev) const {
79
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(ev, stream_));
80 81
  }
#else
82
  void RecordEvent(cudaEvent_t ev) const {
83
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(ev, stream_));
84
  }
85
#endif
86

87 88
#ifdef PADDLE_WITH_HIP
  void WaitEvent(hipEvent_t ev) const {
89
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(stream_, ev, 0));
90 91
  }
#else
92
  void WaitEvent(cudaEvent_t ev) const {
93
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(stream_, ev, 0));
94
  }
95
#endif
96 97 98 99

  void Wait() const;
  void WaitCallback() const { callback_manager_->Wait(); }

100 101 102
#ifdef PADDLE_WITH_HIP
  const hipStream_t& raw_stream() const { return stream_; }
#else
103
  const cudaStream_t& raw_stream() const { return stream_; }
104
#endif
105 106
  void Destroy();

107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
  bool Query() const {
#ifdef PADDLE_WITH_HIP
    hipError_t err = hipStreamQuery(stream_);
    if (err == hipSuccess) {
      return true;
    }
    if (err == hipErrorNotReady) {
      return false;
    }
#else
    cudaError_t err = cudaStreamQuery(stream_);
    if (err == cudaSuccess) {
      return true;
    }
    if (err == cudaErrorNotReady) {
      return false;
    }
#endif

126
    PADDLE_ENFORCE_GPU_SUCCESS(err);
127 128 129
    return false;
  }

130
  void Synchronize() const { platform::GpuStreamSync(stream_); }
131

132 133
  const Place& GetPlace() const { return place_; }

W
Wilber 已提交
134 135 136
  // Note: Can only be used under thread_local semantics.
  void SetStream(gpuStream_t stream);

137 138
 private:
  Place place_;
W
Wilber 已提交
139
  bool owned_stream_{true};
140 141 142
#ifdef PADDLE_WITH_HIP
  hipStream_t stream_{nullptr};
#else
143
  cudaStream_t stream_{nullptr};
144
#endif
145
  Priority priority_{Priority::kNormal};
146
  std::unique_ptr<StreamCallbackManager<gpuStream_t>> callback_manager_;
147
#endif
148 149 150
  DISABLE_COPY_AND_ASSIGN(CUDAStream);
};

// Declaration only; no definition is visible in this header.
// NOTE(review): presumably returns the CUDAStream currently associated with
// device `deviceId` — confirm against the definition.
CUDAStream* get_current_stream(int deviceId);
// Declaration only; no definition is visible in this header.
// NOTE(review): presumably makes `stream` the current stream and returns the
// one it replaced — confirm against the definition.
CUDAStream* set_current_stream(CUDAStream* stream);

}  // namespace stream
}  // namespace platform
}  // namespace paddle