cuda_stream.cc 4.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/stream/cuda_stream.h"
16

17
#include "paddle/fluid/platform/cuda_device_guard.h"
W
Wilber 已提交
18
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
19
#include "paddle/fluid/platform/device_context.h"
20 21 22 23 24 25
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace platform {
namespace stream {

26 27
bool CUDAStream::Init(const Place& place,
                      const Priority& priority,
28
                      const StreamFlag& flag) {
29 30
  PADDLE_ENFORCE_EQ(is_gpu_place(place),
                    true,
31 32 33
                    platform::errors::InvalidArgument(
                        "Cuda stream must be created using cuda place."));
  place_ = place;
34
  CUDADeviceGuard guard(place_.device);
35
  if (priority == Priority::kHigh) {
36
#ifdef PADDLE_WITH_HIP
37
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
38
        &stream_, static_cast<unsigned int>(flag), -1));
39
#else
40
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
41
        &stream_, static_cast<unsigned int>(flag), -1));
42
#endif
43
  } else if (priority == Priority::kNormal) {
44
#ifdef PADDLE_WITH_HIP
45
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
46
        &stream_, static_cast<unsigned int>(flag), 0));
47
#else
48
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
49
        &stream_, static_cast<unsigned int>(flag), 0));
50
#endif
51
  }
52 53
  callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
  VLOG(3) << "GPUStream Init stream: " << stream_
54 55
          << ", priority: " << static_cast<int>(priority)
          << ", flag:" << static_cast<int>(flag);
56 57 58 59
  return true;
}

void CUDAStream::Destroy() {
60
  CUDADeviceGuard guard(place_.device);
61 62
  Wait();
  WaitCallback();
W
Wilber 已提交
63
  if (stream_ && owned_stream_) {
64
#ifdef PADDLE_WITH_HIP
65
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
66
#else
67
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_));
68
#endif
69 70 71 72 73
  }
  stream_ = nullptr;
}

void CUDAStream::Wait() const {
74 75 76 77 78 79 80 81 82 83 84
#ifdef PADDLE_WITH_HIP
  hipError_t e_sync = hipSuccess;
#if !defined(_WIN32)
  e_sync = hipStreamSynchronize(stream_);
#else
  while (e_sync = hipStreamQuery(stream_)) {
    if (e_sync == hipErrorNotReady) continue;
    break;
  }
#endif
#else
85 86 87 88 89 90 91 92 93
  cudaError_t e_sync = cudaSuccess;
#if !defined(_WIN32)
  e_sync = cudaStreamSynchronize(stream_);
#else
  while (e_sync = cudaStreamQuery(stream_)) {
    if (e_sync == cudaErrorNotReady) continue;
    break;
  }
#endif
94
#endif  // PADDLE_WITH_HIP
95

96
  PADDLE_ENFORCE_GPU_SUCCESS(e_sync);
97 98
}

W
Wilber 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112
// Replaces the stream used by this object with the externally provided
// `stream` and binds a fresh callback manager to it. The new stream is
// borrowed: owned_stream_ becomes false, so this object will never destroy it.
//
// If the current stream is owned, it is released through Destroy(). Destroy()
// makes place_'s device current and calls Wait() and WaitCallback() before
// destroying the handle. That matters here because the callback manager bound
// to the old stream is discarded right below. NOTE(review): a borrowed old
// stream is not drained before its callback manager is replaced; confirm
// callers never switch away from a borrowed stream with callbacks in flight.
//
// Note: Can only be used under thread_local semantics.
void CUDAStream::SetStream(gpuStream_t stream) {
  if (owned_stream_ && stream_) {
    // Same guarded wait-then-destroy path as normal teardown, instead of a
    // bare {cuda,hip}StreamDestroy on whatever device happens to be current.
    Destroy();
  }
  owned_stream_ = false;
  stream_ = stream;
  callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
}

113 114 115 116 117 118 119 120 121 122
// Returns the CUDAStream held by the CUDADeviceContext of GPU `deviceId`.
// A deviceId of -1 selects platform::GetCurrentDeviceId().
// In builds without CUDA/HIP support this raises an Unavailable error via
// PADDLE_THROW.
CUDAStream* get_current_stream(int deviceId) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  const int target_id =
      (deviceId == -1) ? platform::GetCurrentDeviceId() : deviceId;
  auto& context_pool = platform::DeviceContextPool::Instance();
  platform::Place target_place = CUDAPlace(target_id);
  auto* gpu_context = static_cast<platform::CUDADeviceContext*>(
      context_pool.Get(target_place));
  return gpu_context->GetCudaStream();
#else
  PADDLE_THROW(platform::errors::Unavailable(
      "Paddle is not compiled with CUDA. Cannot visit cuda current stream."));
  return nullptr;
#endif
}

132 133 134 135 136
// Installs `stream` as the current stream of the CUDADeviceContext for
// stream->GetPlace(), and returns whatever CUDADeviceContext::SetCudaStream
// returns (presumably the previously installed stream -- confirm in
// device_context.h).
//
// Raises InvalidArgument if `stream` is null. In builds without CUDA/HIP
// support, raises Unavailable via PADDLE_THROW.
CUDAStream* set_current_stream(CUDAStream* stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // `stream` is dereferenced immediately below; report a clear error instead
  // of crashing on a null pointer.
  PADDLE_ENFORCE_NOT_NULL(
      stream,
      platform::errors::InvalidArgument(
          "The stream passed to set_current_stream must not be nullptr."));
  auto& device = stream->GetPlace();
  auto& pool = platform::DeviceContextPool::Instance();
  return static_cast<platform::CUDADeviceContext*>(pool.Get(device))
      ->SetCudaStream(stream);
#else
  // The trailing space in the first literal keeps "current stream." readable
  // after adjacent-literal concatenation.
  PADDLE_THROW(platform::errors::Unavailable(
      "Paddle is not compiled with CUDA. Cannot visit cuda current "
      "stream."));
  // PADDLE_THROW is expected to throw. The function returns a pointer, so the
  // placeholder return is nullptr (as in get_current_stream), not a
  // CUDAStream object.
  return nullptr;
#endif
}
145 146 147
}  // namespace stream
}  // namespace platform
}  // namespace paddle