/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace platform {
namespace stream {

bool CUDAStream::Init(const Place& place, const Priority& priority,
                      const StreamFlag& flag) {
26 27 28 29
  PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
                    platform::errors::InvalidArgument(
                        "Cuda stream must be created using cuda place."));
  place_ = place;
30
  CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device);
31
  if (priority == Priority::kHigh) {
32
#ifdef PADDLE_WITH_HIP
33
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
34
        &stream_, static_cast<unsigned int>(flag), -1));
35
#else
36
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
37
        &stream_, static_cast<unsigned int>(flag), -1));
38
#endif
39
  } else if (priority == Priority::kNormal) {
40
#ifdef PADDLE_WITH_HIP
41
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
42
        &stream_, static_cast<unsigned int>(flag), 0));
43
#else
44
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
45
        &stream_, static_cast<unsigned int>(flag), 0));
46
#endif
47
  }
48 49
  callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
  VLOG(3) << "GPUStream Init stream: " << stream_
50 51
          << ", priority: " << static_cast<int>(priority)
          << ", flag:" << static_cast<int>(flag);
52 53 54 55
  return true;
}

void CUDAStream::Destroy() {
56
  CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device);
57 58
  Wait();
  WaitCallback();
W
Wilber 已提交
59
  if (stream_ && owned_stream_) {
60
#ifdef PADDLE_WITH_HIP
61
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
62
#else
63
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_));
64
#endif
65 66 67 68 69
  }
  stream_ = nullptr;
}

void CUDAStream::Wait() const {
70 71 72 73 74 75 76 77 78 79 80
#ifdef PADDLE_WITH_HIP
  hipError_t e_sync = hipSuccess;
#if !defined(_WIN32)
  e_sync = hipStreamSynchronize(stream_);
#else
  while (e_sync = hipStreamQuery(stream_)) {
    if (e_sync == hipErrorNotReady) continue;
    break;
  }
#endif
#else
81 82 83 84 85 86 87 88 89
  cudaError_t e_sync = cudaSuccess;
#if !defined(_WIN32)
  e_sync = cudaStreamSynchronize(stream_);
#else
  while (e_sync = cudaStreamQuery(stream_)) {
    if (e_sync == cudaErrorNotReady) continue;
    break;
  }
#endif
90
#endif  // PADDLE_WITH_HIP
91

92
  PADDLE_ENFORCE_GPU_SUCCESS(e_sync);
93 94
}

// Makes this object operate on the externally supplied `stream`, which it
// does NOT take ownership of (owned_stream_ becomes false).
//
// If the current stream_ is owned by this object, it is destroyed before
// being replaced. Re-installing the handle that is already current is a
// no-op. The previous code destroyed that stream and then kept using the
// released handle. A fresh StreamCallbackManager is bound to the new stream.
//
// Note: Can only be used under thread_local semantics.
void CUDAStream::SetStream(gpuStream_t stream) {
  if (stream_ != nullptr && stream == stream_) {
    // Same handle as the current one: nothing to release or rebind, and
    // destroying it here would leave stream_ dangling.
    return;
  }
  if (owned_stream_ && stream_) {
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_));
#endif
  }
  owned_stream_ = false;
  stream_ = stream;
  callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
}

// Returns the CUDAStream used by the device context of GPU `deviceId`.
// A deviceId of -1 selects platform::GetCurrentDeviceId().
//
// The pointer comes from `.get()` on the context's stream handle, so no
// ownership is transferred (NOTE(review): presumably DeviceContextPool keeps
// it alive; confirm). Builds without CUDA/HIP raise an Unavailable error.
CUDAStream* get_current_stream(int deviceId) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  const int device_id =
      (deviceId == -1) ? platform::GetCurrentDeviceId() : deviceId;

  platform::Place place = CUDAPlace(device_id);
  auto* gpu_ctx = static_cast<platform::CUDADeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));

  return gpu_ctx->context()->Stream().get();
#else
  PADDLE_THROW(platform::errors::Unavailable(
      "Paddle is not compiled with CUDA. Cannot visit cuda current stream."));
  return nullptr;
#endif
}

// Installs `stream` on the device context that matches stream->GetPlace().
// Returns whatever that context's SetStream() returns (NOTE(review):
// presumably the previously installed stream; confirm against the
// context's SetStream declaration).
//
// Raises InvalidArgument if `stream` is null; it was previously dereferenced
// unchecked. Builds without CUDA/HIP raise an Unavailable error.
CUDAStream* set_current_stream(CUDAStream* stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  PADDLE_ENFORCE_NOT_NULL(
      stream, platform::errors::InvalidArgument(
                  "The stream passed to set_current_stream must not be null."));
  auto& device = stream->GetPlace();
  auto& pool = platform::DeviceContextPool::Instance();
  return static_cast<platform::CUDADeviceContext*>(pool.Get(device))
      ->context()
      ->SetStream(stream);
#else
  PADDLE_THROW(platform::errors::Unavailable(
      "Paddle is not compiled with CUDA. Cannot visit cuda current stream."));
  return nullptr;
#endif
}
}  // namespace stream
}  // namespace platform
}  // namespace paddle