cuda_stream.cc 4.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/stream/cuda_stream.h"
16

17
#include "paddle/fluid/platform/cuda_device_guard.h"
W
Wilber 已提交
18
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
19
#include "paddle/fluid/platform/device_context.h"
20 21 22 23 24 25
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace platform {
namespace stream {

26 27
bool CUDAStream::Init(const Place& place, const Priority& priority,
                      const StreamFlag& flag) {
28 29 30 31
  PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
                    platform::errors::InvalidArgument(
                        "Cuda stream must be created using cuda place."));
  place_ = place;
32
  CUDADeviceGuard guard(place_.device);
33
  if (priority == Priority::kHigh) {
34
#ifdef PADDLE_WITH_HIP
35
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
36
        &stream_, static_cast<unsigned int>(flag), -1));
37
#else
38
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
39
        &stream_, static_cast<unsigned int>(flag), -1));
40
#endif
41
  } else if (priority == Priority::kNormal) {
42
#ifdef PADDLE_WITH_HIP
43
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
44
        &stream_, static_cast<unsigned int>(flag), 0));
45
#else
46
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
47
        &stream_, static_cast<unsigned int>(flag), 0));
48
#endif
49
  }
50 51
  callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
  VLOG(3) << "GPUStream Init stream: " << stream_
52 53
          << ", priority: " << static_cast<int>(priority)
          << ", flag:" << static_cast<int>(flag);
54 55 56 57
  return true;
}

void CUDAStream::Destroy() {
58
  CUDADeviceGuard guard(place_.device);
59 60
  Wait();
  WaitCallback();
W
Wilber 已提交
61
  if (stream_ && owned_stream_) {
62
#ifdef PADDLE_WITH_HIP
63
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
64
#else
65
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_));
66
#endif
67 68 69 70 71
  }
  stream_ = nullptr;
}

void CUDAStream::Wait() const {
72 73 74 75 76 77 78 79 80 81 82
#ifdef PADDLE_WITH_HIP
  hipError_t e_sync = hipSuccess;
#if !defined(_WIN32)
  e_sync = hipStreamSynchronize(stream_);
#else
  while (e_sync = hipStreamQuery(stream_)) {
    if (e_sync == hipErrorNotReady) continue;
    break;
  }
#endif
#else
83 84 85 86 87 88 89 90 91
  cudaError_t e_sync = cudaSuccess;
#if !defined(_WIN32)
  e_sync = cudaStreamSynchronize(stream_);
#else
  while (e_sync = cudaStreamQuery(stream_)) {
    if (e_sync == cudaErrorNotReady) continue;
    break;
  }
#endif
92
#endif  // PADDLE_WITH_HIP
93

94
  PADDLE_ENFORCE_GPU_SUCCESS(e_sync);
95 96
}

W
Wilber 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109 110
// Replaces the wrapped runtime stream with an externally provided one.
//
// If this object currently owns a non-null stream, that stream is destroyed
// first. Afterwards the object is marked as not owning its stream, and a new
// callback manager is created for `stream`.
//
// Note: Can only be used under thread_local semantics.
void CUDAStream::SetStream(gpuStream_t stream) {
  const bool release_previous = owned_stream_ && stream_;
  if (release_previous) {
#ifdef PADDLE_WITH_HIP
    const auto destroy_status = hipStreamDestroy(stream_);
#else
    const auto destroy_status = cudaStreamDestroy(stream_);
#endif
    PADDLE_ENFORCE_GPU_SUCCESS(destroy_status);
  }

  stream_ = stream;
  owned_stream_ = false;
  callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
}

111 112 113 114 115 116 117 118 119 120
// Returns the CUDAStream held by the device context of the given GPU.
//
// `deviceId` == -1 means "use the id reported by GetCurrentDeviceId()". The
// context is looked up in the global DeviceContextPool for CUDAPlace(id).
// In builds without CUDA/HIP support this always raises an Unavailable error.
CUDAStream* get_current_stream(int deviceId) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  const int resolved_id =
      (deviceId == -1) ? platform::GetCurrentDeviceId() : deviceId;

  auto& ctx_pool = platform::DeviceContextPool::Instance();
  platform::Place target_place = CUDAPlace(resolved_id);

  auto* gpu_ctx =
      static_cast<platform::CUDADeviceContext*>(ctx_pool.Get(target_place));
  return gpu_ctx->GetCudaStream();
#else
  PADDLE_THROW(platform::errors::Unavailable(
      "Paddle is not compiled with CUDA. Cannot visit cuda current stream."));
  return nullptr;
#endif
}

130 131 132 133 134
// Installs `stream` on the device context that matches the stream's place.
//
// The context is looked up in the global DeviceContextPool using
// stream->GetPlace(), and the call is forwarded to
// CUDADeviceContext::SetCudaStream, whose result is returned unchanged
// (presumably the previously installed stream -- confirm against
// CUDADeviceContext).
//
// `stream` must not be null; a null pointer is rejected with an
// InvalidArgument error instead of being dereferenced.
// In builds without CUDA/HIP support this always raises an Unavailable error.
CUDAStream* set_current_stream(CUDAStream* stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  PADDLE_ENFORCE_NOT_NULL(
      stream, platform::errors::InvalidArgument(
                  "The stream passed to set_current_stream must not be null."));
  auto& device = stream->GetPlace();
  auto& pool = platform::DeviceContextPool::Instance();
  return static_cast<platform::CUDADeviceContext*>(pool.Get(device))
      ->SetCudaStream(stream);
#else
  // The trailing space keeps the two adjacent literals from fusing into
  // "currentstream.".
  PADDLE_THROW(platform::errors::Unavailable(
      "Paddle is not compiled with CUDA. Cannot visit cuda current "
      "stream."));
  // Unreachable after the throw; must be a pointer to match the return type.
  return nullptr;
#endif
}
143 144 145
}  // namespace stream
}  // namespace platform
}  // namespace paddle