device_context.cc 13.8 KB
Newer Older
W
Wilber 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/core/device_context.h"
16

Y
Yuang Liu 已提交
17 18 19 20
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#endif

21
#include "paddle/phi/core/dense_tensor.h"
22
#include "paddle/phi/core/enforce.h"
23
#include "paddle/phi/core/selected_rows.h"
J
Jack Zhou 已提交
24
#include "paddle/phi/core/string_tensor.h"
W
Wilber 已提交
25

26
namespace phi {
27
using DataType = paddle::experimental::DataType;
W
Wilber 已提交
28 29 30 31 32

// Pimpl holder for all DeviceContext state: non-owning pointers to the
// allocators used for device / host / zero-size / pinned allocations, the
// optional CUDA-graph capture allocator, and the device/host random number
// generators. All pointers are borrowed; the creator of the DeviceContext
// owns their lifetimes.
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  // Sets the allocator used for regular device memory. Must not be nullptr.
  void SetAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    device_allocator_ = allocator;
  }

  // Sets the allocator used for host (CPU) memory. Must not be nullptr.
  void SetHostAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_allocator_ = allocator;
  }

  // Sets the allocator used for zero-size device tensors. Must not be nullptr.
  void SetZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    zero_allocator_ = allocator;
  }

  // Sets the allocator used for zero-size host tensors. Must not be nullptr.
  void SetHostZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_zero_allocator_ = allocator;
  }

  // Sets the allocator used for pinned (page-locked) host memory.
  // Must not be nullptr.
  void SetPinnedAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    pinned_allocator_ = allocator;
  }

#ifdef PADDLE_WITH_CUDA
  void SetCUDAGraphAllocator(const Allocator* allocator) {
    // NOTE (Yuang): cuda graph allocator can be set to nullptr, so don't check
    // validation of the allocator here
    cuda_graph_allocator_ = allocator;
  }

  const Allocator& GetCUDAGraphAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
                            phi::errors::InvalidArgument(
                                "Required cuda_graph_allocator_ shall not be "
                                "nullptr, but received nullptr."));
    return *cuda_graph_allocator_;
  }

  bool IsCUDAGraphAllocatorValid() const {
    return cuda_graph_allocator_ != nullptr;
  }
#endif

  const Allocator& GetAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_allocator_,
        phi::errors::InvalidArgument("Required device_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *device_allocator_;
  }

  const Allocator& GetHostAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_allocator_,
        phi::errors::InvalidArgument("Required host_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_allocator_;
  }

  const Allocator& GetZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        zero_allocator_,
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *zero_allocator_;
  }

  const Allocator& GetHostZeroAllocator() const {
    // NOTE: the message previously named zero_allocator_, which is a
    // different member; it now names the member actually being checked.
    PADDLE_ENFORCE_NOT_NULL(
        host_zero_allocator_,
        phi::errors::InvalidArgument("Required host_zero_allocator_ shall not "
                                     "be nullptr, but received nullptr."));
    return *host_zero_allocator_;
  }

  const Allocator& GetPinnedAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        pinned_allocator_,
        phi::errors::InvalidArgument("Required pinned_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *pinned_allocator_;
  }

  // Allocates storage for `tensor` on `place` and returns the raw pointer.
  // When dtype is UNDEFINED the tensor's current dtype is kept. A zero-numel
  // tensor is routed to the zero allocator; `pinned` selects the pinned host
  // allocator instead of the device one. While a CUDA graph is being captured
  // on this thread, device allocations must come from the graph allocator.
  void* Alloc(TensorBase* tensor,
              const Place& place,
              DataType dtype = DataType::UNDEFINED,
              size_t requested_size = 0,
              bool pinned = false) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    // NOTE(paddle-dev): In case of tensor has already hold allocation and
    // is going to allocate allocation on new place, we will clear its holder
    // firstly and then re-alloc it.
    if (tensor->initialized() && tensor->place() != place) {
      ClearHolder(tensor);
    }
    auto* allocator = tensor->numel() == 0
                          ? zero_allocator_
                          : (pinned ? pinned_allocator_ : device_allocator_);
#ifdef PADDLE_WITH_CUDA
    bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned;
    if (must_cuda_graph_allocator && paddle::platform::is_gpu_place(place) &&
        paddle::platform::CUDAGraph::IsThisThreadCapturing()) {
      PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
                              phi::errors::InvalidArgument(
                                  "Required cuda_graph_allocator_ shall not be "
                                  "nullptr, but received nullptr."));
      allocator = cuda_graph_allocator_;
    }
#endif
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience wrapper around Alloc(); dtype is derived from T.
  template <typename T>
  T* Alloc(TensorBase* tensor,
           const Place& place,
           size_t requested_size = 0,
           bool pinned = false) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(Alloc(tensor, place, dtype, requested_size, pinned));
  }

  // Allocates host (CPU) storage for `tensor`; zero-numel tensors use the
  // host zero allocator. An existing non-CPU holder is cleared first.
  void* HostAlloc(TensorBase* tensor,
                  DataType dtype = DataType::UNDEFINED,
                  size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    if (tensor->initialized() && tensor->place() != CPUPlace()) {
      ClearHolder(tensor);
    }
    auto* allocator =
        tensor->numel() == 0 ? host_zero_allocator_ : host_allocator_;
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience wrapper around HostAlloc(); dtype is derived from T.
  template <typename T>
  T* HostAlloc(phi::TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
  }

  // Sets the device-side random number generator. Must not be nullptr.
  void SetGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    device_generator_ = gen;
  }

  Generator* GetGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_generator_,
        phi::errors::InvalidArgument("Required generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return device_generator_;
  }

  // Sets the host-side random number generator. Must not be nullptr.
  void SetHostGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    host_generator_ = gen;
  }

  Generator* GetHostGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_generator_,
        phi::errors::InvalidArgument("Required generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return host_generator_;
  }

 private:
  // Releases the tensor's current allocation so it can be re-allocated on a
  // different place. Only the tensor kinds listed below are supported.
  void ClearHolder(TensorBase* tensor) const {
    if (!tensor->initialized()) return;

    if (DenseTensor::classof(tensor)) {
      static_cast<DenseTensor*>(tensor)->clear();
    } else if (SelectedRows::classof(tensor)) {
      static_cast<SelectedRows*>(tensor)->mutable_value()->clear();
    } else if (StringTensor::classof(tensor)) {
      static_cast<StringTensor*>(tensor)->clear();
    } else {
      // NOTE: message updated to mention StringTensor, which the branch above
      // already handles; the old text predated StringTensor support.
      PADDLE_THROW(errors::Unimplemented(
          "Only support DenseTensor, SelectedRows and StringTensor now."));
    }
  }

  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
  const Allocator* host_zero_allocator_{nullptr};
  const Allocator* pinned_allocator_{nullptr};
#ifdef PADDLE_WITH_CUDA
  const Allocator* cuda_graph_allocator_{nullptr};
#endif
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

// Default construction: create a fresh, empty Impl.
DeviceContext::DeviceContext() : impl_(std::make_unique<Impl>()) {}

// Copy construction: build a new Impl and mirror every allocator/generator
// pointer of `other` into it (shallow copy of the borrowed pointers).
DeviceContext::DeviceContext(const DeviceContext& other) {
  // impl_ starts out as a null unique_ptr here (the copy constructor does not
  // delegate to the default constructor, and impl_ has no in-class
  // initializer), so it must be allocated before the calls below --
  // otherwise every copy dereferences a null pointer.
  impl_ = std::make_unique<Impl>();
  impl_->SetHostAllocator(&other.GetHostAllocator());
  impl_->SetAllocator(&other.GetAllocator());
  impl_->SetZeroAllocator(&other.GetZeroAllocator());
  impl_->SetHostZeroAllocator(&other.GetHostZeroAllocator());
  impl_->SetPinnedAllocator(&other.GetPinnedAllocator());
  impl_->SetHostGenerator(other.GetHostGenerator());
  impl_->SetGenerator(other.GetGenerator());
#ifdef PADDLE_WITH_CUDA
  // The CUDA graph allocator is allowed to be unset; copy it only when valid
  // because SetCUDAGraphAllocator accepts nullptr but GetCUDAGraphAllocator
  // enforces non-null.
  if (other.IsCUDAGraphAllocatorValid()) {
    impl_->SetCUDAGraphAllocator(&other.GetCUDAGraphAllocator());
  }
#endif
}

// Move construction: steal the whole Impl; `other` is left with a null impl_.
DeviceContext::DeviceContext(DeviceContext&& other)
    : impl_(std::move(other.impl_)) {}

284
// Move assignment: defaulted — transfers ownership of impl_.
DeviceContext& DeviceContext::operator=(DeviceContext&& other) = default;
285

W
Wilber 已提交
286 287
// Defaulted destructor; defined here so the unique_ptr<Impl> member can be
// destroyed where Impl is a complete type.
DeviceContext::~DeviceContext() = default;

W
Wilber 已提交
288 289
// Thin forwarding wrappers: the public allocator accessors delegate straight
// to the Impl, which performs the null checks.

void DeviceContext::SetAllocator(const Allocator* allocator) {
  impl_->SetAllocator(allocator);
}

const Allocator& DeviceContext::GetAllocator() const {
  return impl_->GetAllocator();
}

void DeviceContext::SetHostAllocator(const Allocator* allocator) {
  impl_->SetHostAllocator(allocator);
}

const Allocator& DeviceContext::GetHostAllocator() const {
  return impl_->GetHostAllocator();
}

Y
Yuang Liu 已提交
304 305 306 307 308 309 310 311 312 313 314 315 316 317
#ifdef PADDLE_WITH_CUDA
// CUDA-graph allocator accessors; only compiled for CUDA builds. The setter
// tolerates nullptr (see Impl), the getter enforces non-null.
void DeviceContext::SetCUDAGraphAllocator(const Allocator* allocator) {
  impl_->SetCUDAGraphAllocator(allocator);
}

const Allocator& DeviceContext::GetCUDAGraphAllocator() const {
  return impl_->GetCUDAGraphAllocator();
}

bool DeviceContext::IsCUDAGraphAllocatorValid() const {
  return impl_->IsCUDAGraphAllocatorValid();
}
#endif

318 319
// Forwarding wrappers for the zero-size (device and host) and pinned-memory
// allocators; all validation happens inside Impl.

void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
  impl_->SetZeroAllocator(allocator);
}

void DeviceContext::SetHostZeroAllocator(const Allocator* allocator) {
  impl_->SetHostZeroAllocator(allocator);
}

const Allocator& DeviceContext::GetZeroAllocator() const {
  return impl_->GetZeroAllocator();
}

const Allocator& DeviceContext::GetHostZeroAllocator() const {
  return impl_->GetHostZeroAllocator();
}

void DeviceContext::SetPinnedAllocator(const Allocator* allocator) {
  impl_->SetPinnedAllocator(allocator);
}

const Allocator& DeviceContext::GetPinnedAllocator() const {
  return impl_->GetPinnedAllocator();
}

341 342
void* DeviceContext::Alloc(TensorBase* tensor,
                           DataType dtype,
W
wanghuancoder 已提交
343 344
                           size_t requested_size,
                           bool pinned) const {
345 346 347 348
  if (pinned) {
    return impl_->Alloc(
        tensor, GetPinnedPlace(GetPlace()), dtype, requested_size, pinned);
  }
W
wanghuancoder 已提交
349
  return impl_->Alloc(tensor, GetPlace(), dtype, requested_size, pinned);
W
Wilber 已提交
350
}
W
Wilber 已提交
351

352
template <typename T>
W
wanghuancoder 已提交
353 354 355
T* DeviceContext::Alloc(TensorBase* tensor,
                        size_t requested_size,
                        bool pinned) const {
356 357 358 359
  if (pinned) {
    return impl_->Alloc<T>(
        tensor, GetPinnedPlace(GetPlace()), requested_size, pinned);
  }
W
wanghuancoder 已提交
360
  return impl_->Alloc<T>(tensor, GetPlace(), requested_size, pinned);
361 362 363 364 365 366 367 368 369 370 371 372 373 374
}

// Allocates host (CPU) memory for `tensor`; delegates to Impl.
void* DeviceContext::HostAlloc(TensorBase* tensor,
                               DataType dtype,
                               size_t requested_size) const {
  return impl_->HostAlloc(tensor, dtype, requested_size);
}

// Typed variant of HostAlloc(); the element type T determines the dtype.
template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
  return impl_->HostAlloc<T>(tensor, requested_size);
}

#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype)              \
W
wanghuancoder 已提交
375 376
  template dtype* DeviceContext::Alloc(                              \
      TensorBase* tensor, size_t requested_size, bool pinned) const; \
377 378 379 380 381 382 383 384 385 386 387 388 389 390 391
  template dtype* DeviceContext::HostAlloc(TensorBase* tensor,       \
                                           size_t requested_size) const;

DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(bool)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(uint8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int16_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int32_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int64_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(double)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::float16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
J
Jack Zhou 已提交
392
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::pstring)
393 394 395

#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION

W
Wilber 已提交
396 397 398 399
// Device-side random generator accessors; null checks live in Impl.
void DeviceContext::SetGenerator(Generator* gen) {
  impl_->SetGenerator(gen);
}

Generator* DeviceContext::GetGenerator() const {
  return impl_->GetGenerator();
}

L
Leo Chen 已提交
400 401 402 403 404 405 406 407
// Host-side random generator accessors; null checks live in Impl.
void DeviceContext::SetHostGenerator(Generator* gen) { impl_->SetHostGenerator(gen); }

Generator* DeviceContext::GetHostGenerator() const { return impl_->GetHostGenerator(); }

408
}  // namespace phi