device_context.cc 12.7 KB
Newer Older
W
Wilber 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/core/device_context.h"
16

Y
Yuang Liu 已提交
17 18 19 20
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#endif

21
#include "paddle/phi/core/dense_tensor.h"
22
#include "paddle/phi/core/enforce.h"
23
#include "paddle/phi/core/selected_rows.h"
J
Jack Zhou 已提交
24
#include "paddle/phi/core/string_tensor.h"
W
Wilber 已提交
25

26
namespace phi {
27
using DataType = paddle::experimental::DataType;
W
Wilber 已提交
28 29 30 31 32

// Private implementation of DeviceContext (pimpl). Stores the allocators and
// generators registered on a context and implements tensor allocation.
// Impl stores plain pointers and never deletes them; ownership stays with
// whoever registered them (TODO confirm against the callers).
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  // Registers the allocator used for regular (non-pinned, non-empty) device
  // allocations. Throws InvalidArgument if `allocator` is nullptr.
  void SetAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    device_allocator_ = allocator;
  }

  // Registers the allocator used by HostAlloc for non-empty tensors.
  // Throws InvalidArgument if `allocator` is nullptr.
  void SetHostAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_allocator_ = allocator;
  }

  // Registers the allocator used by Alloc/HostAlloc when the tensor has
  // numel() == 0. Throws InvalidArgument if `allocator` is nullptr.
  void SetZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    zero_allocator_ = allocator;
  }

  // Registers the allocator used by Alloc when `pinned` is true and the
  // tensor is non-empty. Throws InvalidArgument if `allocator` is nullptr.
  void SetPinnedAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    pinned_allocator_ = allocator;
  }

#ifdef PADDLE_WITH_CUDA
  // Registers the allocator Alloc must use while this thread is capturing a
  // CUDA graph. nullptr is accepted and means "no CUDA graph allocator".
  void SetCUDAGraphAllocator(const Allocator* allocator) {
    // NOTE (Yuang): cuda graph allocator can be set to nullptr, so don't check
    // validation of the allocator here
    cuda_graph_allocator_ = allocator;
  }

  // Returns the CUDA graph allocator; throws InvalidArgument if none is set.
  const Allocator& GetCUDAGraphAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
                            phi::errors::InvalidArgument(
                                "Required cuda_graph_allocator_ shall not be "
                                "nullptr, but received nullptr."));
    return *cuda_graph_allocator_;
  }

  // True iff a CUDA graph allocator is currently registered.
  bool IsCUDAGraphAllocatorValid() const {
    return cuda_graph_allocator_ != nullptr;
  }
#endif

  // Returns the device allocator; throws InvalidArgument if none is set.
  const Allocator& GetAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_allocator_,
        phi::errors::InvalidArgument("Required device_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *device_allocator_;
  }

  // Returns the host allocator; throws InvalidArgument if none is set.
  const Allocator& GetHostAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_allocator_,
        phi::errors::InvalidArgument("Required host_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_allocator_;
  }

  // Returns the zero-size allocator; throws InvalidArgument if none is set.
  const Allocator& GetZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        zero_allocator_,
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *zero_allocator_;
  }

  // Returns the pinned allocator; throws InvalidArgument if none is set.
  const Allocator& GetPinnedAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        pinned_allocator_,
        phi::errors::InvalidArgument("Required pinned_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *pinned_allocator_;
  }

  // Allocates storage for `tensor` for use on `place` and returns the pointer
  // produced by TensorBase::AllocateFrom.
  //
  // - `dtype` == UNDEFINED means "use tensor->dtype()".
  // - If the tensor already holds an allocation on a different place, that
  //   holder is cleared first so the tensor is re-allocated.
  // - Allocator choice: zero allocator when numel() == 0; otherwise the pinned
  //   allocator when `pinned`, else the device allocator. Under CUDA, a
  //   non-empty, non-pinned allocation on a GPU place made while this thread
  //   is capturing a CUDA graph uses the CUDA graph allocator instead.
  //
  // Throws InvalidArgument if `tensor` is nullptr, or if a CUDA graph
  // allocator is required but not registered.
  void* Alloc(TensorBase* tensor,
              const Place& place,
              DataType dtype = DataType::UNDEFINED,
              size_t requested_size = 0,
              bool pinned = false) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    // NOTE(paddle-dev): In case of tensor has already hold allocation and
    // is going to allocate allocation on new place, we will clear its holder
    // firstly and then re-alloc it.
    if (tensor->initialized() && tensor->place() != place) {
      ClearHolder(tensor);
    }
    auto* allocator = tensor->numel() == 0
                          ? zero_allocator_
                          : (pinned ? pinned_allocator_ : device_allocator_);
#ifdef PADDLE_WITH_CUDA
    // Allocations made during CUDA graph capture must come from the graph's
    // dedicated allocator; empty and pinned allocations are exempt.
    bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned;
    if (must_cuda_graph_allocator && paddle::platform::is_gpu_place(place) &&
        paddle::platform::CUDAGraph::IsThisThreadCapturing()) {
      PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
                              phi::errors::InvalidArgument(
                                  "Required cuda_graph_allocator_ shall not be "
                                  "nullptr, but received nullptr."));
      allocator = cuda_graph_allocator_;
    }
#endif
    // Allocators are stored as const pointers, but AllocateFrom presumably
    // takes a non-const Allocator*; hence the const_cast.
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience overload: derives the DataType from T and casts the
  // returned buffer to T*.
  template <typename T>
  T* Alloc(TensorBase* tensor,
           const Place& place,
           size_t requested_size = 0,
           bool pinned = false) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(Alloc(tensor, place, dtype, requested_size, pinned));
  }

  // Host counterpart of Alloc. Targets CPUPlace and uses the zero allocator
  // for empty tensors, otherwise the host allocator. Clears an existing
  // non-CPU holder first. Throws InvalidArgument if `tensor` is nullptr.
  void* HostAlloc(TensorBase* tensor,
                  DataType dtype = DataType::UNDEFINED,
                  size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    if (tensor->initialized() && tensor->place() != CPUPlace()) {
      ClearHolder(tensor);
    }
    auto* allocator = tensor->numel() == 0 ? zero_allocator_ : host_allocator_;
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience overload of HostAlloc.
  template <typename T>
  T* HostAlloc(phi::TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
  }

  // Registers the device generator; throws InvalidArgument if `gen` is null.
  void SetGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    device_generator_ = gen;
  }

  // Returns the device generator; throws InvalidArgument if none is set.
  Generator* GetGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_generator_,
        phi::errors::InvalidArgument("Required device_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return device_generator_;
  }

  // Registers the host generator; throws InvalidArgument if `gen` is null.
  void SetHostGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    host_generator_ = gen;
  }

  // Returns the host generator; throws InvalidArgument if none is set.
  Generator* GetHostGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_generator_,
        phi::errors::InvalidArgument("Required host_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return host_generator_;
  }

 private:
  // Releases the current allocation of an initialized tensor so that it can
  // be re-allocated. Handles DenseTensor, SelectedRows (clears its value
  // tensor) and StringTensor; throws Unimplemented for any other subtype.
  // Does nothing if the tensor is not initialized.
  void ClearHolder(TensorBase* tensor) const {
    if (!tensor->initialized()) return;

    if (DenseTensor::classof(tensor)) {
      static_cast<DenseTensor*>(tensor)->clear();
    } else if (SelectedRows::classof(tensor)) {
      static_cast<SelectedRows*>(tensor)->mutable_value()->clear();
    } else if (StringTensor::classof(tensor)) {
      static_cast<StringTensor*>(tensor)->clear();
    } else {
      PADDLE_THROW(errors::Unimplemented(
          "Only support DenseTensor, SelectedRows and StringTensor now."));
    }
  }

  // Registered allocators; nullptr until the matching setter is called.
  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
  const Allocator* pinned_allocator_{nullptr};
#ifdef PADDLE_WITH_CUDA
  // May legitimately stay nullptr (see SetCUDAGraphAllocator).
  const Allocator* cuda_graph_allocator_{nullptr};
#endif
  // Registered generators; nullptr until the matching setter is called.
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

// Default constructor: every context gets its own fresh Impl, so all
// allocators and generators start out unset (nullptr).
DeviceContext::DeviceContext() : impl_(std::make_unique<Impl>()) {}

// Copy constructor. Creates a fresh Impl and registers on it the same
// allocator and generator pointers that `other` holds. The pointers are
// shared; the allocators and generators themselves are not copied.
//
// `impl_` must be created before the setters below dereference it; without
// this every copy dereferenced a null Impl.
//
// The getters called on `other` enforce non-null, so this throws
// InvalidArgument if `other` lacks a host, device, zero or pinned allocator,
// or a host or device generator. The CUDA graph allocator is optional and is
// copied only when `other` has one.
DeviceContext::DeviceContext(const DeviceContext& other)
    : impl_(std::make_unique<Impl>()) {
  impl_->SetHostAllocator(&other.GetHostAllocator());
  impl_->SetAllocator(&other.GetAllocator());
  impl_->SetZeroAllocator(&other.GetZeroAllocator());
  impl_->SetPinnedAllocator(&other.GetPinnedAllocator());
  impl_->SetHostGenerator(other.GetHostGenerator());
  impl_->SetGenerator(other.GetGenerator());
#ifdef PADDLE_WITH_CUDA
  if (other.IsCUDAGraphAllocatorValid()) {
    impl_->SetCUDAGraphAllocator(&other.GetCUDAGraphAllocator());
  }
#endif
}

// Move constructor: takes over `other`'s Impl instead of building a new one.
// Assuming impl_ is a std::unique_ptr (it is filled via std::make_unique),
// `other` is left without an Impl and should only be destroyed or assigned.
DeviceContext::DeviceContext(DeviceContext&& other)
    : impl_(std::move(other.impl_)) {}

265
// Move assignment: member-wise move, i.e. `impl_` is taken from `other`.
// Defaulted in this .cc rather than inline, presumably because Impl is only a
// complete type here — confirm against device_context.h.
DeviceContext& DeviceContext::operator=(DeviceContext&& other) = default;
266

W
Wilber 已提交
267 268
// Destructor: destroying `impl_` needs the full definition of Impl, which is
// only available in this file, so the defaulted destructor lives here.
// The registered allocators/generators are not freed (Impl never deletes them).
DeviceContext::~DeviceContext() = default;

W
Wilber 已提交
269 270
void DeviceContext::SetAllocator(const Allocator* allocator) {
  impl_->SetAllocator(allocator);
W
Wilber 已提交
271 272
}

W
Wilber 已提交
273 274
const Allocator& DeviceContext::GetAllocator() const {
  return impl_->GetAllocator();
275 276 277 278
}

void DeviceContext::SetHostAllocator(const Allocator* allocator) {
  impl_->SetHostAllocator(allocator);
W
Wilber 已提交
279 280 281 282
}

const Allocator& DeviceContext::GetHostAllocator() const {
  return impl_->GetHostAllocator();
W
Wilber 已提交
283 284
}

Y
Yuang Liu 已提交
285 286 287 288 289 290 291 292 293 294 295 296 297 298
#ifdef PADDLE_WITH_CUDA
// CUDA graph allocator accessors. Setting nullptr is allowed (it clears the
// allocator); the getter throws InvalidArgument when none is registered, so
// callers can probe with IsCUDAGraphAllocatorValid() first.
void DeviceContext::SetCUDAGraphAllocator(const Allocator* allocator) {
  this->impl_->SetCUDAGraphAllocator(allocator);
}

const Allocator& DeviceContext::GetCUDAGraphAllocator() const {
  return this->impl_->GetCUDAGraphAllocator();
}

bool DeviceContext::IsCUDAGraphAllocatorValid() const {
  return this->impl_->IsCUDAGraphAllocatorValid();
}
#endif

299 300
void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
  impl_->SetZeroAllocator(allocator);
W
Wilber 已提交
301 302
}

303 304 305
const Allocator& DeviceContext::GetZeroAllocator() const {
  return impl_->GetZeroAllocator();
}
W
Wilber 已提交
306

W
wanghuancoder 已提交
307 308 309 310 311 312 313
void DeviceContext::SetPinnedAllocator(const Allocator* allocator) {
  impl_->SetPinnedAllocator(allocator);
}
const Allocator& DeviceContext::GetPinnedAllocator() const {
  return impl_->GetPinnedAllocator();
}

314 315
void* DeviceContext::Alloc(TensorBase* tensor,
                           DataType dtype,
W
wanghuancoder 已提交
316 317 318
                           size_t requested_size,
                           bool pinned) const {
  return impl_->Alloc(tensor, GetPlace(), dtype, requested_size, pinned);
W
Wilber 已提交
319
}
W
Wilber 已提交
320

321
template <typename T>
W
wanghuancoder 已提交
322 323 324 325
T* DeviceContext::Alloc(TensorBase* tensor,
                        size_t requested_size,
                        bool pinned) const {
  return impl_->Alloc<T>(tensor, GetPlace(), requested_size, pinned);
326 327 328 329 330 331 332 333 334 335 336 337 338 339
}

// Allocates `tensor` in host (CPU) memory; see Impl::HostAlloc for details.
void* DeviceContext::HostAlloc(TensorBase* tensor,
                               DataType dtype,
                               size_t requested_size) const {
  auto* const impl = impl_.get();
  return impl->HostAlloc(tensor, dtype, requested_size);
}

// Typed variant of HostAlloc; explicitly instantiated further below.
template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
  auto* const impl = impl_.get();
  return impl->HostAlloc<T>(tensor, requested_size);
}

#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype)              \
W
wanghuancoder 已提交
340 341
  template dtype* DeviceContext::Alloc(                              \
      TensorBase* tensor, size_t requested_size, bool pinned) const; \
342 343 344 345 346 347 348 349 350 351 352 353 354 355 356
  template dtype* DeviceContext::HostAlloc(TensorBase* tensor,       \
                                           size_t requested_size) const;

DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(bool)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(uint8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int16_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int32_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int64_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(double)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::float16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
J
Jack Zhou 已提交
357
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::pstring)
358 359 360

#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION

W
Wilber 已提交
361 362 363 364
void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }

Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }

L
Leo Chen 已提交
365 366 367 368 369 370 371 372
void DeviceContext::SetHostGenerator(Generator* gen) {
  impl_->SetHostGenerator(gen);
}

Generator* DeviceContext::GetHostGenerator() const {
  return impl_->GetHostGenerator();
}

373
}  // namespace phi