device_context.cc 8.6 KB
Newer Older
W
Wilber 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_base.h"
W
Wilber 已提交
18

19
namespace phi {
20
// Local alias so the allocation helpers in this file can name dtypes
// (e.g. DataType::UNDEFINED) without the paddle::experimental qualifier.
using DataType = paddle::experimental::DataType;
W
Wilber 已提交
21 22 23 24 25

// Private implementation behind DeviceContext. It stores non-owning pointers
// to the allocators and generators handed in by the owner, validates them on
// access, and routes tensor allocation requests to the matching allocator.
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  // Sets the allocator used by Alloc() for non-empty requests.
  // Raises InvalidArgument if `allocator` is nullptr.
  void SetAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    device_allocator_ = allocator;
  }

  // Sets the allocator used by HostAlloc() for non-empty requests.
  // Raises InvalidArgument if `allocator` is nullptr.
  void SetHostAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_allocator_ = allocator;
  }

  // Sets the allocator used by Alloc()/HostAlloc() for empty requests
  // (tensor->numel() == 0 and no explicit requested_size).
  // Raises InvalidArgument if `allocator` is nullptr.
  void SetZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    zero_allocator_ = allocator;
  }

  // Returns the device allocator; raises InvalidArgument if it was never set.
  const Allocator& GetAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_allocator_,
        phi::errors::InvalidArgument("Required device_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *device_allocator_;
  }

  // Returns the host allocator; raises InvalidArgument if it was never set.
  const Allocator& GetHostAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_allocator_,
        phi::errors::InvalidArgument("Required host_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_allocator_;
  }

  // Returns the zero-size allocator; raises InvalidArgument if it was never
  // set. (The message now names the member actually being checked.)
  const Allocator& GetZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        zero_allocator_,
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *zero_allocator_;
  }

  // Allocates storage for `tensor` from the device allocator (or from the
  // zero allocator for an empty request) and returns the pointer produced by
  // TensorBase::AllocateFrom.
  //   dtype:          DataType::UNDEFINED means "use tensor->dtype()".
  //   requested_size: forwarded unchanged to AllocateFrom.
  // Raises InvalidArgument if `tensor` is nullptr.
  void* Alloc(TensorBase* tensor,
              DataType dtype = DataType::UNDEFINED,
              size_t requested_size = 0) const {
    return AllocFrom(device_allocator_, tensor, dtype, requested_size);
  }

  // Typed variant of Alloc(): the dtype is derived from T.
  template <typename T>
  T* Alloc(TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(Alloc(tensor, dtype, requested_size));
  }

  // Same as Alloc(), but non-empty requests go to the host allocator.
  void* HostAlloc(TensorBase* tensor,
                  DataType dtype = DataType::UNDEFINED,
                  size_t requested_size = 0) const {
    return AllocFrom(host_allocator_, tensor, dtype, requested_size);
  }

  // Typed variant of HostAlloc(): the dtype is derived from T.
  template <typename T>
  T* HostAlloc(TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
  }

  // Sets the device generator; raises InvalidArgument if `gen` is nullptr.
  void SetGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    device_generator_ = gen;
  }

  // Returns the device generator; raises InvalidArgument if it was never set.
  Generator* GetGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_generator_,
        phi::errors::InvalidArgument("Required device_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return device_generator_;
  }

  // Sets the host generator; raises InvalidArgument if `gen` is nullptr.
  void SetHostGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    host_generator_ = gen;
  }

  // Returns the host generator; raises InvalidArgument if it was never set.
  Generator* GetHostGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_generator_,
        phi::errors::InvalidArgument("Required host_generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return host_generator_;
  }

 private:
  // Shared body of Alloc()/HostAlloc(). `sized_allocator` serves every
  // request except a truly empty one, which goes to zero_allocator_.
  void* AllocFrom(const Allocator* sized_allocator,
                  TensorBase* tensor,
                  DataType dtype,
                  size_t requested_size) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    // Only an empty request (numel() == 0 AND no explicit byte count) may be
    // served by the zero allocator. A caller can pass requested_size > 0 for
    // a tensor whose meta still reports numel() == 0; that request needs a
    // real allocator.
    // NOTE(review): confirm the zero allocator ignores the requested size.
    const Allocator* allocator =
        (tensor->numel() == 0 && requested_size == 0) ? zero_allocator_
                                                      : sized_allocator;
    // AllocateFrom takes a mutable Allocator*; the stored pointers are const.
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Non-owning pointers; the caller keeps the pointees alive.
  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

// Default constructor: every context starts with its own, empty Impl (no
// allocators or generators registered yet).
DeviceContext::DeviceContext() : impl_(std::make_unique<Impl>()) {}

// Copy constructor. Builds a fresh Impl that points at the same allocators
// and generators as `other`; the pointers are shared, not cloned. Each getter
// on `other` raises InvalidArgument if that resource is unset, so only a
// fully configured context can be copied.
DeviceContext::DeviceContext(const DeviceContext& other)
    : impl_(std::make_unique<Impl>()) {
  // impl_ is created in the initializer list above; without it the Set*
  // calls below would dereference a null pointer.
  impl_->SetHostAllocator(&other.GetHostAllocator());
  impl_->SetAllocator(&other.GetAllocator());
  impl_->SetZeroAllocator(&other.GetZeroAllocator());
  impl_->SetHostGenerator(other.GetHostGenerator());
  impl_->SetGenerator(other.GetGenerator());
}

// Move constructor: takes over `other`'s Impl. Afterwards `other` has no
// Impl, so it should only be assigned to or destroyed.
DeviceContext::DeviceContext(DeviceContext&& other)
    : impl_(std::move(other.impl_)) {}

171 172
// Move assignment (defaulted, member-wise). impl_ is assigned from
// std::make_unique, so presumably a std::unique_ptr; the source context is
// then left without an Impl and should only be assigned to or destroyed.
DeviceContext& DeviceContext::operator=(DeviceContext&&) = default;

W
Wilber 已提交
173 174
// Defaulted out of line, here where DeviceContext::Impl is a complete type, so
// destroying impl_ can delete the Impl (presumably the reason it is not
// defaulted in the header).
DeviceContext::~DeviceContext() = default;

W
Wilber 已提交
175 176
void DeviceContext::SetAllocator(const Allocator* allocator) {
  impl_->SetAllocator(allocator);
W
Wilber 已提交
177 178
}

W
Wilber 已提交
179 180
// Returns the device allocator; Impl raises InvalidArgument if it is unset.
const Allocator& DeviceContext::GetAllocator() const {
  const Impl& impl = *impl_;
  return impl.GetAllocator();
}

void DeviceContext::SetHostAllocator(const Allocator* allocator) {
  impl_->SetHostAllocator(allocator);
W
Wilber 已提交
185 186 187 188
}

// Returns the host allocator; Impl raises InvalidArgument if it is unset.
const Allocator& DeviceContext::GetHostAllocator() const {
  const Impl& impl = *impl_;
  return impl.GetHostAllocator();
}

191 192
void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
  impl_->SetZeroAllocator(allocator);
W
Wilber 已提交
193 194
}

195 196 197
// Returns the zero-size allocator; Impl raises InvalidArgument if it is unset.
const Allocator& DeviceContext::GetZeroAllocator() const {
  const Impl& impl = *impl_;
  return impl.GetZeroAllocator();
}
W
Wilber 已提交
198

199 200 201 202
// Device-side allocation for `tensor`; dtype and requested_size are passed
// straight through to Impl::Alloc.
void* DeviceContext::Alloc(TensorBase* tensor,
                           DataType dtype,
                           size_t requested_size) const {
  void* const data = this->impl_->Alloc(tensor, dtype, requested_size);
  return data;
}
W
Wilber 已提交
204

205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
// Typed device allocation; Impl derives the dtype from T. Explicitly
// instantiated later in this file for the supported element types.
template <typename T>
T* DeviceContext::Alloc(TensorBase* tensor, size_t requested_size) const {
  T* const typed = impl_->Alloc<T>(tensor, requested_size);
  return typed;
}

// Host-side allocation for `tensor`; dtype and requested_size are passed
// straight through to Impl::HostAlloc.
void* DeviceContext::HostAlloc(TensorBase* tensor,
                               DataType dtype,
                               size_t requested_size) const {
  void* const data = this->impl_->HostAlloc(tensor, dtype, requested_size);
  return data;
}

// Typed host allocation; Impl derives the dtype from T. Explicitly
// instantiated later in this file for the supported element types.
template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
  T* const typed = impl_->HostAlloc<T>(tensor, requested_size);
  return typed;
}

// The typed Alloc<T>/HostAlloc<T> members are defined in this .cc file, so
// each element type callers may use must be explicitly instantiated here.
#define DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(elem_t)            \
  template elem_t* DeviceContext::Alloc(TensorBase* tensor,      \
                                        size_t requested_size)   \
      const;                                                     \
  template elem_t* DeviceContext::HostAlloc(TensorBase* tensor,  \
                                            size_t requested_size) const;

// Boolean and integral element types.
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(bool)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(int8_t)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(uint8_t)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(int16_t)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(int32_t)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(int64_t)
// Floating-point element types.
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(float)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(double)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(::paddle::experimental::float16)
// Complex element types.
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(::paddle::experimental::complex64)
DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC(::paddle::experimental::complex128)

#undef DEVICE_CONTEXT_INSTANTIATE_TYPED_ALLOC

W
Wilber 已提交
242 243 244 245
// Registers the device generator; Impl::SetGenerator rejects nullptr.
void DeviceContext::SetGenerator(Generator* gen) {
  this->impl_->SetGenerator(gen);
}

// Returns the device generator; Impl raises InvalidArgument if it is unset.
Generator* DeviceContext::GetGenerator() const {
  return this->impl_->GetGenerator();
}

L
Leo Chen 已提交
246 247 248 249 250 251 252 253
// Registers the host generator; Impl::SetHostGenerator rejects nullptr.
void DeviceContext::SetHostGenerator(Generator* gen) {
  Impl& impl = *impl_;
  impl.SetHostGenerator(gen);
}

// Returns the host generator; Impl raises InvalidArgument if it is unset.
Generator* DeviceContext::GetHostGenerator() const {
  const Impl& impl = *impl_;
  return impl.GetHostGenerator();
}

254
}  // namespace phi