//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/device_context.h"

#include <memory>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/string_tensor.h"

namespace phi {
22
using DataType = paddle::experimental::DataType;
W
Wilber 已提交
23 24 25 26 27

// Private implementation ("pimpl") of DeviceContext.
//
// Holds non-owning pointers to the three allocators (device, host, and the
// special allocator used for zero-element tensors) plus the device/host
// random-number generators, and implements the tensor allocation logic
// shared by every concrete device context.
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  // ---- Allocator injection ----------------------------------------------
  // Each Set* rejects nullptr and stores the pointer WITHOUT taking
  // ownership; the allocator must outlive this context.

  void SetAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    device_allocator_ = allocator;
  }

  void SetHostAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    host_allocator_ = allocator;
  }

  void SetZeroAllocator(const Allocator* allocator) {
    PADDLE_ENFORCE_NOT_NULL(
        allocator,
        phi::errors::InvalidArgument(
            "Required allocator shall not be nullptr, but received nullptr."));
    zero_allocator_ = allocator;
  }

  // ---- Allocator access --------------------------------------------------
  // Each Get* enforces that the corresponding allocator was injected.

  const Allocator& GetAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_allocator_,
        phi::errors::InvalidArgument("Required device_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *device_allocator_;
  }

  const Allocator& GetHostAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_allocator_,
        phi::errors::InvalidArgument("Required host_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *host_allocator_;
  }

  const Allocator& GetZeroAllocator() const {
    PADDLE_ENFORCE_NOT_NULL(
        zero_allocator_,
        // Fixed: this message previously named host_allocator_ by mistake.
        phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
                                     "nullptr, but received nullptr."));
    return *zero_allocator_;
  }

  // Allocates memory for `tensor` on `place` via the device allocator.
  // When `dtype` is UNDEFINED the tensor's current dtype is kept.
  // Zero-element tensors are routed to the zero-size allocator.
  void* Alloc(TensorBase* tensor,
              const Place& place,
              DataType dtype = DataType::UNDEFINED,
              size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    // NOTE(paddle-dev): In case of tensor has already hold allocation and
    // is going to allocate allocation on new place, we will clear its holder
    // firstly and then re-alloc it.
    if (tensor->initialized() && tensor->place() != place) {
      ClearHolder(tensor);
    }
    auto* allocator =
        tensor->numel() == 0 ? zero_allocator_ : device_allocator_;
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience overload; dtype is deduced from T.
  template <typename T>
  T* Alloc(TensorBase* tensor,
           const Place& place,
           size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(Alloc(tensor, place, dtype, requested_size));
  }

  // Allocates host (CPU) memory for `tensor` via the host allocator.
  // A tensor currently holding memory on a non-CPU place is cleared first
  // so that AllocateFrom re-allocates on the CPU.
  void* HostAlloc(TensorBase* tensor,
                  DataType dtype = DataType::UNDEFINED,
                  size_t requested_size = 0) const {
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        phi::errors::InvalidArgument(
            "Required tensor shall not be nullptr, but received nullptr."));
    if (dtype == DataType::UNDEFINED) {
      dtype = tensor->dtype();
    }
    if (tensor->initialized() && tensor->place() != CPUPlace()) {
      ClearHolder(tensor);
    }
    auto* allocator = tensor->numel() == 0 ? zero_allocator_ : host_allocator_;
    return tensor->AllocateFrom(
        const_cast<Allocator*>(allocator), dtype, requested_size);
  }

  // Typed convenience overload; dtype is deduced from T.
  template <typename T>
  T* HostAlloc(phi::TensorBase* tensor, size_t requested_size = 0) const {
    DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
    return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
  }

  // ---- Random-number generators (non-owning, must be non-null) -----------

  void SetGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    device_generator_ = gen;
  }

  Generator* GetGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        device_generator_,
        phi::errors::InvalidArgument("Required generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return device_generator_;
  }

  void SetHostGenerator(Generator* gen) {
    PADDLE_ENFORCE_NOT_NULL(
        gen,
        phi::errors::InvalidArgument(
            "Required generator shall not be nullptr, but received nullptr."));
    host_generator_ = gen;
  }

  Generator* GetHostGenerator() const {
    PADDLE_ENFORCE_NOT_NULL(
        host_generator_,
        phi::errors::InvalidArgument("Required generator_ shall not be "
                                     "nullptr, but received nullptr."));
    return host_generator_;
  }

 private:
  // Drops the tensor's current holder so the next AllocateFrom re-allocates
  // on the requested place.
  void ClearHolder(TensorBase* tensor) const {
    if (!tensor->initialized()) return;

    if (DenseTensor::classof(tensor)) {
      static_cast<DenseTensor*>(tensor)->clear();
    } else if (SelectedRows::classof(tensor)) {
      static_cast<SelectedRows*>(tensor)->mutable_value()->clear();
    } else if (StringTensor::classof(tensor)) {
      static_cast<StringTensor*>(tensor)->clear();
    } else {
      // Fixed: message now matches the branches above (StringTensor is
      // handled too but was missing from the original text).
      PADDLE_THROW(errors::Unimplemented(
          "Only support DenseTensor, SelectedRows and StringTensor now."));
    }
  }

  // Non-owning pointers; injected via the Set* methods above.
  const Allocator* device_allocator_{nullptr};
  const Allocator* host_allocator_{nullptr};
  const Allocator* zero_allocator_{nullptr};
  Generator* device_generator_{nullptr};
  Generator* host_generator_{nullptr};
};

// Default constructor: creates an empty Impl; allocators and generators
// must be injected through the Set* methods before the context is used.
DeviceContext::DeviceContext() : impl_(std::make_unique<Impl>()) {}

// Copy constructor: builds a fresh Impl and copies the (non-owning)
// allocator and generator pointers from `other`.
DeviceContext::DeviceContext(const DeviceContext& other) {
  // NOTE(review): the previous code dereferenced impl_ without allocating
  // it first; unless the header default-initializes impl_ (not visible
  // here -- TODO confirm), that is a null unique_ptr dereference. Creating
  // the Impl explicitly is correct in either case.
  impl_ = std::make_unique<Impl>();
  impl_->SetHostAllocator(&other.GetHostAllocator());
  impl_->SetAllocator(&other.GetAllocator());
  impl_->SetZeroAllocator(&other.GetZeroAllocator());
  impl_->SetHostGenerator(other.GetHostGenerator());
  impl_->SetGenerator(other.GetGenerator());
}

// Move constructor: steals the whole Impl from `other`, which is left with
// a null impl_ and is only safe to destroy or assign to afterwards.
DeviceContext::DeviceContext(DeviceContext&& other)
    : impl_(std::move(other.impl_)) {}

200
DeviceContext& DeviceContext::operator=(DeviceContext&& other) = default;
201

W
Wilber 已提交
202 203
DeviceContext::~DeviceContext() = default;

W
Wilber 已提交
204 205
void DeviceContext::SetAllocator(const Allocator* allocator) {
  impl_->SetAllocator(allocator);
W
Wilber 已提交
206 207
}

W
Wilber 已提交
208 209
const Allocator& DeviceContext::GetAllocator() const {
  return impl_->GetAllocator();
210 211 212 213
}

// Injects the host (CPU) allocator (non-owning; must outlive this context).
void DeviceContext::SetHostAllocator(const Allocator* allocator) {
  impl_->SetHostAllocator(allocator);
}

// Returns the injected host allocator; enforces it was set.
const Allocator& DeviceContext::GetHostAllocator() const {
  return impl_->GetHostAllocator();
}

220 221
void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
  impl_->SetZeroAllocator(allocator);
W
Wilber 已提交
222 223
}

224 225 226
const Allocator& DeviceContext::GetZeroAllocator() const {
  return impl_->GetZeroAllocator();
}
W
Wilber 已提交
227

228 229 230
void* DeviceContext::Alloc(TensorBase* tensor,
                           DataType dtype,
                           size_t requested_size) const {
231
  return impl_->Alloc(tensor, GetPlace(), dtype, requested_size);
W
Wilber 已提交
232
}
W
Wilber 已提交
233

234 235
template <typename T>
T* DeviceContext::Alloc(TensorBase* tensor, size_t requested_size) const {
236
  return impl_->Alloc<T>(tensor, GetPlace(), requested_size);
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
}

// Allocates host (CPU) memory for `tensor`; UNDEFINED dtype keeps the
// tensor's current dtype (see Impl::HostAlloc).
void* DeviceContext::HostAlloc(TensorBase* tensor,
                               DataType dtype,
                               size_t requested_size) const {
  return impl_->HostAlloc(tensor, dtype, requested_size);
}

// Typed host allocation; the dtype is deduced from T.
template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
  return impl_->HostAlloc<T>(tensor, requested_size);
}

// Explicitly instantiate the typed Alloc/HostAlloc templates for every
// data type phi supports, so their definitions can stay in this .cc file.
#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype)              \
  template dtype* DeviceContext::Alloc(TensorBase* tensor,           \
                                       size_t requested_size) const; \
  template dtype* DeviceContext::HostAlloc(TensorBase* tensor,       \
                                           size_t requested_size) const;

DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(bool)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(uint8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int16_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int32_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int64_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(double)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::float16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::pstring)

#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION

W
Wilber 已提交
272 273 274 275
void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }

Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }

L
Leo Chen 已提交
276 277 278 279 280 281 282 283
void DeviceContext::SetHostGenerator(Generator* gen) {
  impl_->SetHostGenerator(gen);
}

Generator* DeviceContext::GetHostGenerator() const {
  return impl_->GetHostGenerator();
}

}  // namespace phi