// device_context.cc
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/core/device_context.h"
16
#include "paddle/phi/core/dense_tensor.h"
17
#include "paddle/phi/core/enforce.h"
18
#include "paddle/phi/core/selected_rows.h"
J
Jack Zhou 已提交
19
#include "paddle/phi/core/string_tensor.h"
W
Wilber 已提交
20

21
namespace phi {
22
// Short alias for the element-type enum used throughout the Alloc/HostAlloc
// APIs in this file (e.g. DataType::UNDEFINED as the "keep current" default).
using DataType = paddle::experimental::DataType;
W
Wilber 已提交
23 24 25 26 27

// Private implementation (pimpl) behind DeviceContext.
//
// Stores raw pointers to the allocators and random-number generators that
// were registered on the context. The defaulted destructor never frees them,
// so they are not owned by the Impl. Also implements tensor allocation on
// top of those allocators.
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  // Registers the allocator used by Alloc for non-empty, non-pinned tensors.
  // `allocator` must be non-null (PADDLE_ENFORCE_NOT_NULL, InvalidArgument).
  void SetAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    device_allocator_ = allocator;
  }

  // Registers the allocator used by HostAlloc for non-empty tensors.
  // `allocator` must be non-null.
  void SetHostAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_allocator_ = allocator;
  }

  // Registers the allocator used by both Alloc and HostAlloc when the tensor
  // has numel() == 0. `allocator` must be non-null.
  void SetZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    zero_allocator_ = allocator;
  }

  // Registers the allocator used by Alloc when `pinned` is true and the
  // tensor is non-empty. `allocator` must be non-null.
  void SetPinnedAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    pinned_allocator_ = allocator;
  }

  // Returns the device allocator; fails with InvalidArgument if never set.
  const Allocator& GetAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_allocator_,
        phi::errors::InvalidArgument("Required device_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *device_allocator_;
  }

  // Returns the host allocator; fails with InvalidArgument if never set.
  const Allocator& GetHostAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_allocator_,
        phi::errors::InvalidArgument("Required host_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_allocator_;
  }

  // Returns the zero-size allocator; fails with InvalidArgument if never set.
  const Allocator& GetZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        zero_allocator_,
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *zero_allocator_;
  }

  // Returns the pinned allocator; fails with InvalidArgument if never set.
  const Allocator& GetPinnedAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        pinned_allocator_,
        phi::errors::InvalidArgument("Required pinned_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *pinned_allocator_;
  }

  // Allocates storage for `tensor` on `place` and returns the pointer produced
  // by TensorBase::AllocateFrom.
  //
  //   tensor          must be non-null (InvalidArgument otherwise).
  //   place           target place; an initialized tensor living elsewhere
  //                   has its holder cleared before re-allocation.
  //   dtype           element type; UNDEFINED means "use tensor->dtype()".
  //   requested_size  forwarded unchanged to AllocateFrom.
  //   pinned          choose the pinned allocator instead of the device one.
  //
  // Tensors with numel() == 0 always use the zero allocator.
  // NOTE(review): the selected allocator pointer is not null-checked here;
  // confirm how AllocateFrom behaves if that allocator was never registered.
  void* Alloc(TensorBase* tensor,
              const Place& place,
              DataType dtype = DataType::UNDEFINED,
              size_t requested_size = 0,
              bool pinned = false) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    // NOTE(paddle-dev): In case of tensor has already hold allocation and
    // is going to allocate allocation on new place, we will clear its holder
    // firstly and then re-alloc it.
    if (tensor->initialized() && tensor->place() != place) {
      ClearHolder(tensor);
    }
    auto* allocator = tensor->numel() == 0
                          ? zero_allocator_
                          : (pinned ? pinned_allocator_ : device_allocator_);
    // Allocators are stored as const pointers; AllocateFrom presumably takes
    // a mutable Allocator*, hence the const_cast.
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience overload: derives the dtype from the C++ type T and
  // returns the storage pointer cast to T*.
  template <typename T>
  T* Alloc(TensorBase* tensor,
           const Place& place,
           size_t requested_size = 0,
           bool pinned = false) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(Alloc(tensor, place, dtype, requested_size, pinned));
  }

  // Host-side counterpart of Alloc. An initialized tensor that is not on
  // CPUPlace has its holder cleared first. The host allocator is used, or
  // the zero allocator when numel() == 0. `tensor` must be non-null, and
  // dtype UNDEFINED means "use tensor->dtype()".
  void* HostAlloc(TensorBase* tensor,
                  DataType dtype = DataType::UNDEFINED,
                  size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    if (tensor->initialized() && tensor->place() != CPUPlace()) {
      ClearHolder(tensor);
    }
    auto* allocator = tensor->numel() == 0 ? zero_allocator_ : host_allocator_;
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience overload of HostAlloc; dtype is derived from T.
  template <typename T>
  T* HostAlloc(phi::TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
  }

  // Registers the device generator; `gen` must be non-null.
  void SetGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    device_generator_ = gen;
  }

  // Returns the device generator; fails with InvalidArgument if never set.
  Generator* GetGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_generator_,
        phi::errors::InvalidArgument("Required device_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return device_generator_;
  }

  // Registers the host generator; `gen` must be non-null.
  void SetHostGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    host_generator_ = gen;
  }

  // Returns the host generator; fails with InvalidArgument if never set.
  Generator* GetHostGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_generator_,
        phi::errors::InvalidArgument("Required host_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return host_generator_;
  }

 private:
  // Drops the storage held by an initialized tensor so that it can be
  // re-allocated (called by Alloc/HostAlloc when the place changes).
  // Handles DenseTensor, SelectedRows (clears its value tensor) and
  // StringTensor; any other tensor kind raises Unimplemented.
  void ClearHolder(TensorBase* tensor) const {
    if (!tensor->initialized()) return;

    if (DenseTensor::classof(tensor)) {
      static_cast<DenseTensor*>(tensor)->clear();
    } else if (SelectedRows::classof(tensor)) {
      static_cast<SelectedRows*>(tensor)->mutable_value()->clear();
    } else if (StringTensor::classof(tensor)) {
      static_cast<StringTensor*>(tensor)->clear();
    } else {
      PADDLE_THROW(errors::Unimplemented(
          "Only support DenseTensor, SelectedRows and StringTensor now."));
    }
  }

  // Registered allocators (not owned; nullptr until set).
  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
  const Allocator* pinned_allocator_{nullptr};
  // Registered generators (not owned; nullptr until set).
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

// Default constructor: starts with a fresh Impl that has no allocators or
// generators registered yet.
DeviceContext::DeviceContext() : impl_(std::make_unique<Impl>()) {}

// Copy constructor: builds a new Impl and registers on it the same
// allocators and generators that `other` currently holds (pointers are
// shared, not duplicated).
//
// `impl_` must be created before the setters below dereference it; a
// default-initialized smart pointer is null.
//
// The getters used here enforce non-null, so `other` must have its host,
// device, zero and pinned allocators and its host and device generators all
// registered; otherwise the copy fails with InvalidArgument.
DeviceContext::DeviceContext(const DeviceContext& other)
    : impl_(std::make_unique<Impl>()) {
  impl_->SetHostAllocator(&other.GetHostAllocator());
  impl_->SetAllocator(&other.GetAllocator());
  impl_->SetZeroAllocator(&other.GetZeroAllocator());
  impl_->SetPinnedAllocator(&other.GetPinnedAllocator());
  impl_->SetHostGenerator(other.GetHostGenerator());
  impl_->SetGenerator(other.GetGenerator());
}

// Move constructor: takes over `other`'s Impl; `other.impl_` is left empty
// afterwards.
DeviceContext::DeviceContext(DeviceContext&& other)
    : impl_(std::move(other.impl_)) {}

// Move assignment: member-wise move of the Impl pointer.
DeviceContext& DeviceContext::operator=(DeviceContext&& other) = default;

// Destructor: releases the Impl. The allocators and generators it points to
// are not freed by the Impl's defaulted destructor.
DeviceContext::~DeviceContext() = default;

W
Wilber 已提交
225 226
// Allocator registration. Each setter forwards to the Impl, which rejects a
// nullptr allocator with an InvalidArgument error.
void DeviceContext::SetAllocator(const Allocator* allocator) {
  impl_->SetAllocator(allocator);
}

void DeviceContext::SetHostAllocator(const Allocator* allocator) {
  impl_->SetHostAllocator(allocator);
}

void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
  impl_->SetZeroAllocator(allocator);
}

void DeviceContext::SetPinnedAllocator(const Allocator* allocator) {
  impl_->SetPinnedAllocator(allocator);
}

// Allocator lookup. Each getter forwards to the Impl, which fails with an
// InvalidArgument error if the requested allocator was never registered.
const Allocator& DeviceContext::GetAllocator() const {
  return impl_->GetAllocator();
}

const Allocator& DeviceContext::GetHostAllocator() const {
  return impl_->GetHostAllocator();
}

const Allocator& DeviceContext::GetZeroAllocator() const {
  return impl_->GetZeroAllocator();
}

const Allocator& DeviceContext::GetPinnedAllocator() const {
  return impl_->GetPinnedAllocator();
}

256 257
void* DeviceContext::Alloc(TensorBase* tensor,
                           DataType dtype,
W
wanghuancoder 已提交
258 259 260
                           size_t requested_size,
                           bool pinned) const {
  return impl_->Alloc(tensor, GetPlace(), dtype, requested_size, pinned);
W
Wilber 已提交
261
}
W
Wilber 已提交
262

263
template <typename T>
W
wanghuancoder 已提交
264 265 266 267
T* DeviceContext::Alloc(TensorBase* tensor,
                        size_t requested_size,
                        bool pinned) const {
  return impl_->Alloc<T>(tensor, GetPlace(), requested_size, pinned);
268 269 270 271 272 273 274 275 276 277 278 279 280 281
}

void* DeviceContext::HostAlloc(TensorBase* tensor,
                               DataType dtype,
                               size_t requested_size) const {
  return impl_->HostAlloc(tensor, dtype, requested_size);
}

template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
  return impl_->HostAlloc<T>(tensor, requested_size);
}

#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype)              \
W
wanghuancoder 已提交
282 283
  template dtype* DeviceContext::Alloc(                              \
      TensorBase* tensor, size_t requested_size, bool pinned) const; \
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
  template dtype* DeviceContext::HostAlloc(TensorBase* tensor,       \
                                           size_t requested_size) const;

DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(bool)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(uint8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int16_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int32_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int64_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(double)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::float16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
J
Jack Zhou 已提交
299
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::pstring)
300 301 302

#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION

W
Wilber 已提交
303 304 305 306
void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }

Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }

L
Leo Chen 已提交
307 308 309 310 311 312 313 314
void DeviceContext::SetHostGenerator(Generator* gen) {
  impl_->SetHostGenerator(gen);
}

Generator* DeviceContext::GetHostGenerator() const {
  return impl_->GetHostGenerator();
}

315
}  // namespace phi