//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"

namespace phi {
using DataType = paddle::experimental::DataType;

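// Impl hides DeviceContext's state behind the pImpl idiom: it keeps raw
// pointers to the injected allocators and generators and implements the
// allocation logic shared by all backend-specific contexts.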
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  void SetAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    device_allocator_ = allocator;
  }

  void SetHostAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_allocator_ = allocator;
  }

  void SetZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    zero_allocator_ = allocator;
  }

  const Allocator& GetAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_allocator_,
        phi::errors::InvalidArgument("Required device_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *device_allocator_;
  }

  const Allocator& GetHostAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_allocator_,
        phi::errors::InvalidArgument("Required host_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_allocator_;
  }

  const Allocator& GetZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        zero_allocator_,
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *zero_allocator_;
  }

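  // Device-side allocation: zero-sized tensors go through zero_allocator_,
  // everything else through device_allocator_. When dtype is UNDEFINED the
  // tensor's current dtype is used.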
  void* Alloc(TensorBase* tensor,
              const Place& place,
              DataType dtype = DataType::UNDEFINED,
              size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    // NOTE(paddle-dev): If the tensor already holds an allocation and is about
    // to be allocated on a new place, clear its holder first and then
    // re-allocate.
    if (tensor->initialized() && tensor->place() != place) {
      ClearHolder(tensor);
    }
    auto* allocator =
        tensor->numel() == 0 ? zero_allocator_ : device_allocator_;
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  template <typename T>
  T* Alloc(TensorBase* tensor,
           const Place& place,
           size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(Alloc(tensor, place, dtype, requested_size));
  }

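  // Host-side (CPU) counterpart of Alloc: a tensor that currently lives on a
  // non-CPU place is cleared and re-allocated through host_allocator_
  // (zero-sized tensors again use zero_allocator_).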
  void* HostAlloc(TensorBase* tensor,
                  DataType dtype = DataType::UNDEFINED,
                  size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    if (tensor->initialized() && tensor->place() != CPUPlace()) {
      ClearHolder(tensor);
    }
    auto* allocator = tensor->numel() == 0 ? zero_allocator_ : host_allocator_;
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  template <typename T>
  T* HostAlloc(phi::TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
  }

  void SetGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    device_generator_ = gen;
  }

  Generator* GetGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_generator_,
        phi::errors::InvalidArgument("Required device_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return device_generator_;
  }

  void SetHostGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    host_generator_ = gen;
  }

  Generator* GetHostGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_generator_,
        phi::errors::InvalidArgument("Required generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return host_generator_;
  }

 private:
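  // Drops the tensor's current holder so the next AllocateFrom call gets a
  // fresh allocation on the requested place.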
  void ClearHolder(TensorBase* tensor) const {
    if (!tensor->initialized()) return;

    if (DenseTensor::classof(tensor)) {
      static_cast<DenseTensor*>(tensor)->clear();
    } else if (SelectedRows::classof(tensor)) {
      static_cast<SelectedRows*>(tensor)->mutable_value()->clear();
    } else {
      PADDLE_THROW(errors::Unimplemented(
          "Only DenseTensor and SelectedRows are supported for now."));
    }
  }

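  // Non-owning pointers; the concrete allocators and generators are injected
  // from outside through the Set*() methods above.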
  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }

DeviceContext::DeviceContext(const DeviceContext& other) {
  // impl_ must be constructed before copying state from `other`.
  impl_ = std::make_unique<Impl>();
  impl_->SetHostAllocator(&other.GetHostAllocator());
  impl_->SetAllocator(&other.GetAllocator());
  impl_->SetZeroAllocator(&other.GetZeroAllocator());
  impl_->SetHostGenerator(other.GetHostGenerator());
  impl_->SetGenerator(other.GetGenerator());
}

DeviceContext::DeviceContext(DeviceContext&& other) {
  impl_ = std::move(other.impl_);
}

DeviceContext& DeviceContext::operator=(DeviceContext&& other) = default;

DeviceContext::~DeviceContext() = default;

void DeviceContext::SetAllocator(const Allocator* allocator) {
  impl_->SetAllocator(allocator);
}

const Allocator& DeviceContext::GetAllocator() const {
  return impl_->GetAllocator();
}

void DeviceContext::SetHostAllocator(const Allocator* allocator) {
  impl_->SetHostAllocator(allocator);
}

const Allocator& DeviceContext::GetHostAllocator() const {
  return impl_->GetHostAllocator();
}

void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
  impl_->SetZeroAllocator(allocator);
}

const Allocator& DeviceContext::GetZeroAllocator() const {
  return impl_->GetZeroAllocator();
}

void* DeviceContext::Alloc(TensorBase* tensor,
                           DataType dtype,
                           size_t requested_size) const {
  return impl_->Alloc(tensor, GetPlace(), dtype, requested_size);
}

template <typename T>
T* DeviceContext::Alloc(TensorBase* tensor, size_t requested_size) const {
  return impl_->Alloc<T>(tensor, GetPlace(), requested_size);
}
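
// Typical call site in a phi kernel (a sketch; FooKernel and its arguments are
// hypothetical, but this is the usual way kernels request output memory):
//
//   template <typename T, typename Context>
//   void FooKernel(const Context& dev_ctx, DenseTensor* out) {
//     T* out_data = dev_ctx.template Alloc<T>(out);
//     // ... write kernel results into out_data ...
//   }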

void* DeviceContext::HostAlloc(TensorBase* tensor,
                               DataType dtype,
                               size_t requested_size) const {
  return impl_->HostAlloc(tensor, dtype, requested_size);
}

template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
  return impl_->HostAlloc<T>(tensor, requested_size);
}

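// Explicitly instantiate the templated Alloc/HostAlloc members for every data
// type phi supports, so code in other translation units can link against them.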
#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype)              \
  template dtype* DeviceContext::Alloc(TensorBase* tensor,           \
                                       size_t requested_size) const; \
  template dtype* DeviceContext::HostAlloc(TensorBase* tensor,       \
                                           size_t requested_size) const;

DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(bool)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(uint8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int16_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int32_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int64_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(double)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::float16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)

#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION

void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }

Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }

void DeviceContext::SetHostGenerator(Generator* gen) {
  impl_->SetHostGenerator(gen);
}

Generator* DeviceContext::GetHostGenerator() const {
  return impl_->GetHostGenerator();
}

}  // namespace phi