//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/device_context.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#endif

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/string_tensor.h"

namespace phi {
using DataType = paddle::experimental::DataType;

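// Impl is the private implementation (pimpl) behind DeviceContext. Every
// pointer it stores is non-owning: allocators and generators are created and
// owned elsewhere in the framework and merely registered here, which is why
// the setters take raw pointers and the struct never deletes anything.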
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  void SetAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    device_allocator_ = allocator;
  }

  void SetHostAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_allocator_ = allocator;
  }

  void SetZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    zero_allocator_ = allocator;
  }

  void SetHostZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_zero_allocator_ = allocator;
  }

  void SetPinnedAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    pinned_allocator_ = allocator;
  }

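  // CUDA graph capture records device addresses, so allocations made while a
  // thread is capturing must come from an allocator whose blocks remain valid
  // across graph replays. Unlike the other allocators, this one is optional
  // (nullptr until a capture needs it), hence the explicit validity query
  // below instead of a not-null enforcement in the setter.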
#ifdef PADDLE_WITH_CUDA
  void SetCUDAGraphAllocator(const Allocator* allocator) {
    // NOTE(Yuang): the CUDA graph allocator may legitimately be nullptr, so
    // its validity is not checked here.
    cuda_graph_allocator_ = allocator;
  }

  const Allocator& GetCUDAGraphAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
                            phi::errors::InvalidArgument(
                                "Required cuda_graph_allocator_ shall not be "
                                "nullptr, but received nullptr."));
    return *cuda_graph_allocator_;
  }

  bool IsCUDAGraphAllocatorValid() const {
    return cuda_graph_allocator_ != nullptr;
  }
#endif

  const Allocator& GetAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_allocator_,
        phi::errors::InvalidArgument("Required device_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *device_allocator_;
  }

  const Allocator& GetHostAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_allocator_,
        phi::errors::InvalidArgument("Required host_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_allocator_;
  }

  const Allocator& GetZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        zero_allocator_,
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *zero_allocator_;
  }

  const Allocator& GetHostZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_zero_allocator_,
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_zero_allocator_;
  }

  const Allocator& GetPinnedAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        pinned_allocator_,
        phi::errors::InvalidArgument("Required pinned_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *pinned_allocator_;
  }

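  // Device-side allocation. The allocator is picked in this order: the zero
  // allocator for empty (0-byte) requests, the pinned allocator when pinned
  // host memory is requested, otherwise the device allocator; under CUDA
  // graph capture the CUDA graph allocator overrides the device allocator.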
  void* Alloc(TensorBase* tensor,
              const Place& place,
              DataType dtype = DataType::UNDEFINED,
              size_t requested_size = 0,
              bool pinned = false) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    // NOTE(paddle-dev): If the tensor already holds an allocation and is
    // about to be re-allocated on a different place, clear the old holder
    // first and then re-allocate.
    if (tensor->initialized() && tensor->place() != place) {
      ClearHolder(tensor);
    }
    auto* allocator = tensor->numel() == 0 && requested_size == 0
                          ? zero_allocator_
                          : (pinned ? pinned_allocator_ : device_allocator_);
#ifdef PADDLE_WITH_CUDA
    bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned;
    if (must_cuda_graph_allocator &&
        place.GetType() == phi::AllocationType::GPU &&
        phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) {
      PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
                              phi::errors::InvalidArgument(
                                  "Required cuda_graph_allocator_ shall not be "
                                  "nullptr, but received nullptr."));
      allocator = cuda_graph_allocator_;
    }
#endif
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  template <typename T>
  T* Alloc(TensorBase* tensor,
           const Place& place,
           size_t requested_size = 0,
           bool pinned = false) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(Alloc(tensor, place, dtype, requested_size, pinned));
  }

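  // Host-side counterpart of Alloc: always allocates on CPUPlace, clearing
  // the current holder first if the tensor's memory lives elsewhere.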
  void* HostAlloc(TensorBase* tensor,
                  DataType dtype = DataType::UNDEFINED,
                  size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    if (tensor->initialized() && tensor->place() != CPUPlace()) {
      ClearHolder(tensor);
    }
    auto* allocator =
        tensor->numel() == 0 ? host_zero_allocator_ : host_allocator_;
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  template <typename T>
  T* HostAlloc(phi::TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
  }

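  // Random-number generators, likewise non-owning: device_generator_ backs
  // device-side RNG kernels while host_generator_ serves CPU-side randomness.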
  void SetGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    device_generator_ = gen;
  }

  Generator* GetGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_generator_,
        phi::errors::InvalidArgument("Required generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return device_generator_;
  }

  void SetHostGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    host_generator_ = gen;
  }

  Generator* GetHostGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_generator_,
        phi::errors::InvalidArgument("Required generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return host_generator_;
  }

 private:
  void ClearHolder(TensorBase* tensor) const {
    if (!tensor->initialized()) return;

    if (DenseTensor::classof(tensor)) {
      static_cast<DenseTensor*>(tensor)->clear();
    } else if (SelectedRows::classof(tensor)) {
      static_cast<SelectedRows*>(tensor)->mutable_value()->clear();
    } else if (StringTensor::classof(tensor)) {
      static_cast<StringTensor*>(tensor)->clear();
    } else {
      PADDLE_THROW(errors::Unimplemented(
          "Only DenseTensor, SelectedRows, and StringTensor are supported."));
    }
  }

  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
  const Allocator* host_zero_allocator_{nullptr};
  const Allocator* pinned_allocator_{nullptr};
#ifdef PADDLE_WITH_CUDA
  const Allocator* cuda_graph_allocator_{nullptr};
#endif
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }

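// Copying builds a fresh Impl and re-registers the same non-owning allocator
// and generator pointers held by `other`; the underlying resources are
// shared, not duplicated.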
DeviceContext::DeviceContext(const DeviceContext& other) {
  impl_ = std::make_unique<Impl>();
  impl_->SetHostAllocator(&other.GetHostAllocator());
  impl_->SetAllocator(&other.GetAllocator());
  impl_->SetZeroAllocator(&other.GetZeroAllocator());
  impl_->SetHostZeroAllocator(&other.GetHostZeroAllocator());
  impl_->SetPinnedAllocator(&other.GetPinnedAllocator());
  impl_->SetHostGenerator(other.GetHostGenerator());
  impl_->SetGenerator(other.GetGenerator());
#ifdef PADDLE_WITH_CUDA
  if (other.IsCUDAGraphAllocatorValid()) {
    impl_->SetCUDAGraphAllocator(&other.GetCUDAGraphAllocator());
  }
#endif
}

DeviceContext::DeviceContext(DeviceContext&& other) {
  impl_ = std::move(other.impl_);
}

DeviceContext& DeviceContext::operator=(DeviceContext&& other) = default;

DeviceContext::~DeviceContext() = default;

void DeviceContext::SetAllocator(const Allocator* allocator) {
  impl_->SetAllocator(allocator);
}

const Allocator& DeviceContext::GetAllocator() const {
  return impl_->GetAllocator();
}

void DeviceContext::SetHostAllocator(const Allocator* allocator) {
  impl_->SetHostAllocator(allocator);
}

const Allocator& DeviceContext::GetHostAllocator() const {
  return impl_->GetHostAllocator();
}

#ifdef PADDLE_WITH_CUDA
void DeviceContext::SetCUDAGraphAllocator(const Allocator* allocator) {
  impl_->SetCUDAGraphAllocator(allocator);
}

const Allocator& DeviceContext::GetCUDAGraphAllocator() const {
  return impl_->GetCUDAGraphAllocator();
}

bool DeviceContext::IsCUDAGraphAllocatorValid() const {
  return impl_->IsCUDAGraphAllocatorValid();
}
#endif

void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
  impl_->SetZeroAllocator(allocator);
}

void DeviceContext::SetHostZeroAllocator(const Allocator* allocator) {
  impl_->SetHostZeroAllocator(allocator);
}

const Allocator& DeviceContext::GetZeroAllocator() const {
  return impl_->GetZeroAllocator();
}

const Allocator& DeviceContext::GetHostZeroAllocator() const {
  return impl_->GetHostZeroAllocator();
}

void DeviceContext::SetPinnedAllocator(const Allocator* allocator) {
  impl_->SetPinnedAllocator(allocator);
}
const Allocator& DeviceContext::GetPinnedAllocator() const {
  return impl_->GetPinnedAllocator();
}

void* DeviceContext::Alloc(TensorBase* tensor,
                           DataType dtype,
                           size_t requested_size,
                           bool pinned) const {
  if (pinned) {
    return impl_->Alloc(
        tensor, GetPinnedPlace(GetPlace()), dtype, requested_size, pinned);
  }
  return impl_->Alloc(tensor, GetPlace(), dtype, requested_size, pinned);
}

template <typename T>
T* DeviceContext::Alloc(TensorBase* tensor,
                        size_t requested_size,
                        bool pinned) const {
  if (pinned) {
    return impl_->Alloc<T>(
        tensor, GetPinnedPlace(GetPlace()), requested_size, pinned);
  }
  return impl_->Alloc<T>(tensor, GetPlace(), requested_size, pinned);
}

void* DeviceContext::HostAlloc(TensorBase* tensor,
                               DataType dtype,
                               size_t requested_size) const {
  return impl_->HostAlloc(tensor, dtype, requested_size);
}

template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
  return impl_->HostAlloc<T>(tensor, requested_size);
}

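// The templated Alloc/HostAlloc overloads are defined in this .cc file, so
// they must be explicitly instantiated for every supported element type. For
// example, DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float) expands to
//   template float* DeviceContext::Alloc(
//       TensorBase* tensor, size_t requested_size, bool pinned) const;
//   template float* DeviceContext::HostAlloc(TensorBase* tensor,
//                                            size_t requested_size) const;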
#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype)              \
  template dtype* DeviceContext::Alloc(                              \
      TensorBase* tensor, size_t requested_size, bool pinned) const; \
  template dtype* DeviceContext::HostAlloc(TensorBase* tensor,       \
                                           size_t requested_size) const;

DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(bool)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(uint8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int16_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int32_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int64_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(double)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::float16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::pstring)

#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION

void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }

Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }

void DeviceContext::SetHostGenerator(Generator* gen) {
  impl_->SetHostGenerator(gen);
}

Generator* DeviceContext::GetHostGenerator() const {
  return impl_->GetHostGenerator();
}

}  // namespace phi
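
// Illustrative usage sketch (not part of this file; `alloc` stands in for an
// Allocator implementation owned by the framework, and phi::CPUContext is
// one concrete DeviceContext, since DeviceContext itself is abstract):
//
//   phi::CPUContext ctx;
//   ctx.SetAllocator(&alloc);          // default/device-side allocations
//   ctx.SetHostAllocator(&alloc);      // host-side allocations
//   ctx.SetZeroAllocator(&alloc);      // 0-byte requests
//   ctx.SetHostZeroAllocator(&alloc);
//
//   phi::DenseTensor t;
//   t.Resize({4, 4});
//   float* data = ctx.Alloc<float>(&t);  // typed alloc sized from t's numel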