device_context.cc 10.8 KB
Newer Older
W
Wilber 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/core/device_context.h"
16

17
#include "paddle/phi/core/dense_tensor.h"
18
#include "paddle/phi/core/enforce.h"
19
#include "paddle/phi/core/selected_rows.h"
J
Jack Zhou 已提交
20
#include "paddle/phi/core/string_tensor.h"
W
Wilber 已提交
21

22
namespace phi {
23
using DataType = paddle::experimental::DataType;
W
Wilber 已提交
24 25 26 27 28

// Private implementation behind DeviceContext. Stores non-owning pointers to
// the allocators and random-number generators registered on the context and
// performs tensor memory allocation with them. The destructor does not free
// any of them; ownership presumably stays with whoever registered them.
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  // ---- Allocator registration ---------------------------------------------
  // Every setter rejects nullptr with an InvalidArgument error.

  // Allocator used by Alloc() for non-empty tensors when `pinned` is false.
  void SetAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    device_allocator_ = allocator;
  }

  // Allocator used by HostAlloc() for non-empty tensors.
  void SetHostAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_allocator_ = allocator;
  }

  // Allocator used by Alloc() and HostAlloc() for tensors with numel() == 0.
  void SetZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    zero_allocator_ = allocator;
  }

  // Allocator used by Alloc() for non-empty tensors when `pinned` is true.
  void SetPinnedAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    pinned_allocator_ = allocator;
  }

  // ---- Allocator access ---------------------------------------------------
  // Every getter raises an InvalidArgument error if the allocator is unset.

  const Allocator& GetAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_allocator_,
        phi::errors::InvalidArgument("Required device_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *device_allocator_;
  }

  const Allocator& GetHostAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_allocator_,
        phi::errors::InvalidArgument("Required host_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_allocator_;
  }

  const Allocator& GetZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        zero_allocator_,
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *zero_allocator_;
  }

  const Allocator& GetPinnedAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        pinned_allocator_,
        phi::errors::InvalidArgument("Required pinned_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *pinned_allocator_;
  }

  // ---- Allocation ---------------------------------------------------------

  // Allocates memory for `tensor` on `place` and returns the data pointer.
  //   dtype:          element type; DataType::UNDEFINED keeps tensor->dtype().
  //   requested_size: forwarded unchanged to TensorBase::AllocateFrom.
  //   pinned:         use the pinned allocator instead of the device one.
  // Tensors with numel() == 0 always use the zero allocator. Raises
  // InvalidArgument if `tensor` or the selected allocator is nullptr.
  void* Alloc(TensorBase* tensor,
              const Place& place,
              DataType dtype = DataType::UNDEFINED,
              size_t requested_size = 0,
              bool pinned = false) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    // NOTE(paddle-dev): In case of tensor has already hold allocation and
    // is going to allocate allocation on new place, we will clear its holder
    // firstly and then re-alloc it.
    if (tensor->initialized() && tensor->place() != place) {
      ClearHolder(tensor);
    }
    const Allocator* allocator =
        tensor->numel() == 0
            ? zero_allocator_
            : (pinned ? pinned_allocator_ : device_allocator_);
    // Fail with a clear message instead of handing nullptr to AllocateFrom.
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator for this allocation shall not be nullptr, but "
            "received nullptr. Please set the device, pinned or zero "
            "allocator on the DeviceContext first."));
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience overload: the dtype is derived from T.
  template <typename T>
  T* Alloc(TensorBase* tensor,
           const Place& place,
           size_t requested_size = 0,
           bool pinned = false) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(Alloc(tensor, place, dtype, requested_size, pinned));
  }

  // Allocates memory for `tensor` with the host allocator. If the tensor
  // already holds memory on a place other than CPUPlace(), that memory is
  // released first. dtype/requested_size behave as in Alloc().
  // NOTE(review): empty tensors use zero_allocator_, the same allocator
  // Alloc() uses for device tensors; confirm that is intended for host memory.
  void* HostAlloc(TensorBase* tensor,
                  DataType dtype = DataType::UNDEFINED,
                  size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    if (tensor->initialized() && tensor->place() != CPUPlace()) {
      ClearHolder(tensor);
    }
    const Allocator* allocator =
        tensor->numel() == 0 ? zero_allocator_ : host_allocator_;
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator for this allocation shall not be nullptr, but "
            "received nullptr. Please set the host or zero allocator on the "
            "DeviceContext first."));
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience overload: the dtype is derived from T.
  template <typename T>
  T* HostAlloc(phi::TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
  }

  // ---- Generators ---------------------------------------------------------
  // Setters reject nullptr; getters raise InvalidArgument when unset.

  void SetGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    device_generator_ = gen;
  }

  Generator* GetGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_generator_,
        phi::errors::InvalidArgument("Required device_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return device_generator_;
  }

  void SetHostGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    host_generator_ = gen;
  }

  Generator* GetHostGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_generator_,
        phi::errors::InvalidArgument("Required host_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return host_generator_;
  }

 private:
  // Releases the storage held by an initialized tensor so that it can be
  // re-allocated on another place. Supports DenseTensor, SelectedRows (its
  // value tensor) and StringTensor; any other type raises Unimplemented.
  void ClearHolder(TensorBase* tensor) const {
    if (!tensor->initialized()) return;

    if (DenseTensor::classof(tensor)) {
      static_cast<DenseTensor*>(tensor)->clear();
    } else if (SelectedRows::classof(tensor)) {
      static_cast<SelectedRows*>(tensor)->mutable_value()->clear();
    } else if (StringTensor::classof(tensor)) {
      static_cast<StringTensor*>(tensor)->clear();
    } else {
      PADDLE_THROW(errors::Unimplemented(
          "Only support DenseTensor, SelectedRows and StringTensor now."));
    }
  }

  // Non-owning; see the struct comment.
  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
  const Allocator* pinned_allocator_{nullptr};
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

// Creates a context with a fresh Impl; no allocators or generators are set yet.
DeviceContext::DeviceContext() : impl_(std::make_unique<Impl>()) {}

// Copy-constructs a context that refers to the same allocators and generators
// as `other`. Impl stores only non-owning pointers, so a member-wise copy of it
// shares those objects rather than duplicating them. Entries that were never
// set on `other` stay unset in the copy.
// Raises InvalidArgument if `other` has no Impl (e.g. it was moved from).
DeviceContext::DeviceContext(const DeviceContext& other) {
  PADDLE_ENFORCE_NOT_NULL(
      other.impl_.get(),
      phi::errors::InvalidArgument(
          "Cannot copy a DeviceContext whose implementation is nullptr; the "
          "source context may have been moved from."));
  // impl_ must be created before use: it starts out empty in this constructor.
  impl_ = std::make_unique<Impl>(*other.impl_);
}

// Move-constructs by taking over `other`'s Impl.
DeviceContext::DeviceContext(DeviceContext&& other)
    : impl_(std::move(other.impl_)) {}

222
// Move assignment: a defaulted member-wise move, so impl_ is taken from
// `other`. Defaulted in this file, where Impl is a complete type.
DeviceContext& DeviceContext::operator=(DeviceContext&& other) = default;
223

W
Wilber 已提交
224 225
// Defaulted out of line, in the file that defines Impl, so that destroying
// impl_ sees the complete Impl type.
// NOTE(review): assumes impl_ is a smart pointer to Impl as declared in the
// header (it is assigned from std::make_unique<Impl>() above) — confirm.
DeviceContext::~DeviceContext() = default;

W
Wilber 已提交
226 227
void DeviceContext::SetAllocator(const Allocator* allocator) {
  impl_->SetAllocator(allocator);
W
Wilber 已提交
228 229
}

W
Wilber 已提交
230 231
// Returns the registered device allocator; errors out if none is set.
const Allocator& DeviceContext::GetAllocator() const {
  const auto& impl = *impl_;
  return impl.GetAllocator();
}

void DeviceContext::SetHostAllocator(const Allocator* allocator) {
  impl_->SetHostAllocator(allocator);
W
Wilber 已提交
236 237 238 239
}

// Returns the registered host allocator; errors out if none is set.
const Allocator& DeviceContext::GetHostAllocator() const {
  const auto& impl = *impl_;
  return impl.GetHostAllocator();
}

242 243
void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
  impl_->SetZeroAllocator(allocator);
W
Wilber 已提交
244 245
}

246 247 248
// Returns the zero-element allocator; errors out if none is set.
const Allocator& DeviceContext::GetZeroAllocator() const {
  const auto& impl = *impl_;
  return impl.GetZeroAllocator();
}
W
Wilber 已提交
249

W
wanghuancoder 已提交
250 251 252 253 254 255 256
void DeviceContext::SetPinnedAllocator(const Allocator* allocator) {
  impl_->SetPinnedAllocator(allocator);
}
// Returns the pinned allocator; errors out if none is set.
const Allocator& DeviceContext::GetPinnedAllocator() const {
  const auto& impl = *impl_;
  return impl.GetPinnedAllocator();
}

257 258
void* DeviceContext::Alloc(TensorBase* tensor,
                           DataType dtype,
W
wanghuancoder 已提交
259 260 261
                           size_t requested_size,
                           bool pinned) const {
  return impl_->Alloc(tensor, GetPlace(), dtype, requested_size, pinned);
W
Wilber 已提交
262
}
W
Wilber 已提交
263

264
template <typename T>
W
wanghuancoder 已提交
265 266 267 268
T* DeviceContext::Alloc(TensorBase* tensor,
                        size_t requested_size,
                        bool pinned) const {
  return impl_->Alloc<T>(tensor, GetPlace(), requested_size, pinned);
269 270 271 272 273 274 275 276 277 278 279 280 281 282
}

// Allocates host (CPU) storage for `tensor` and returns the data pointer.
// DataType::UNDEFINED keeps the tensor's current dtype.
void* DeviceContext::HostAlloc(TensorBase* tensor,
                               DataType dtype,
                               size_t requested_size) const {
  auto* impl = impl_.get();
  return impl->HostAlloc(tensor, dtype, requested_size);
}

// Typed host allocation; the dtype is derived from T.
template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
  auto* impl = impl_.get();
  return impl->HostAlloc<T>(tensor, requested_size);
}

#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype)              \
W
wanghuancoder 已提交
283 284
  template dtype* DeviceContext::Alloc(                              \
      TensorBase* tensor, size_t requested_size, bool pinned) const; \
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
  template dtype* DeviceContext::HostAlloc(TensorBase* tensor,       \
                                           size_t requested_size) const;

DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(bool)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(uint8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int16_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int32_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int64_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(double)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::float16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
J
Jack Zhou 已提交
300
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::pstring)
301 302 303

#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION

W
Wilber 已提交
304 305 306 307
// Registers the device random-number generator; nullptr is rejected by Impl.
void DeviceContext::SetGenerator(Generator* gen) {
  this->impl_->SetGenerator(gen);
}

// Returns the device generator; errors out if none is set.
Generator* DeviceContext::GetGenerator() const {
  const auto& impl = *impl_;
  return impl.GetGenerator();
}

L
Leo Chen 已提交
308 309 310 311 312 313 314 315
// Registers the host random-number generator; nullptr is rejected by Impl.
void DeviceContext::SetHostGenerator(Generator* gen) {
  this->impl_->SetHostGenerator(gen);
}

// Returns the host generator; errors out if none is set.
Generator* DeviceContext::GetHostGenerator() const {
  const auto& impl = *impl_;
  return impl.GetHostGenerator();
}

316
}  // namespace phi