tensor.cc 13.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/api/include/tensor.h"
16 17 18 19 20 21

#include <memory>
#include <utility>
#include <vector>

#include "glog/logging.h"
22 23 24 25 26 27
#include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
28 29
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
30 31 32
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
33 34 35 36 37
/**
 * [ Why still include the fluid headers? ]
 *
 * We hope to organize the basic implementation of Tensor and the logic related
 * to Tensor computation into an independent library, which we call
 * [Tensor Operation Library, phi], so we extract or rewrite the original
 * Kernels.
 *
 * In the future, the training library, inference library and custom operators
 * will link to this Tensor Operation library.
 *
 * However, if we directly split the link relation, we need to make too many
 * changes, which will affect the stability of the framework, so here we still
 * rely on the implementation of the framework, which is an intermediate state.
 *
 * In the future, the necessary components will be moved to this library,
 * or the corresponding components will be re-implemented.
 */
51

52 53 54
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
55 56 57 58
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
59 60 61 62 63 64 65 66

namespace paddle {
namespace experimental {

/////// Tensor Methods ////////

/* Part 1: Construction and destruction methods */

67
Tensor::Tensor(std::shared_ptr<phi::TensorBase> tensor_impl)
68
    : impl_(std::move(tensor_impl)) {
69 70 71
  PADDLE_ENFORCE_NOT_NULL(
      impl_,
      phi::errors::InvalidArgument("TensorImpl with nullptr is not supported"));
72 73 74
}

// Creates a DenseTensor-backed Tensor associated with `place`.
// The DenseTensor is built from a SharedStorage created for the converted
// inner place and a meta with DataType::UNDEFINED, an empty DDim and NCHW
// layout. `place` is also remembered in place_, which place() returns while
// the implementation is not yet initialized.
//
// Note: the factory results and the DenseTensorMeta temporary are already
// rvalues, so they are passed directly instead of through std::move.
Tensor::Tensor(const PlaceType &place)
    : impl_(std::make_shared<phi::DenseTensor>(
          phi::make_intrusive<SharedStorage>(
              ConvertExtPlaceToInnerPlace(place)),
          phi::DenseTensorMeta(phi::DataType::UNDEFINED,
                               phi::make_ddim({}),
                               phi::DataLayout::NCHW))),
      place_{place} {}
82 83

// Creates a DenseTensor-backed Tensor associated with `place` whose meta
// carries the dimensions given by `shape`.
// The DenseTensor is built from a SharedStorage created for the converted
// inner place and a meta with DataType::UNDEFINED, make_ddim(shape) and NCHW
// layout. `place` is also remembered in place_, which place() returns while
// the implementation is not yet initialized.
//
// Note: the factory results and the DenseTensorMeta temporary are already
// rvalues, so they are passed directly instead of through std::move.
Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
    : impl_(std::make_shared<phi::DenseTensor>(
          phi::make_intrusive<SharedStorage>(
              ConvertExtPlaceToInnerPlace(place)),
          phi::DenseTensorMeta(phi::DataType::UNDEFINED,
                               phi::make_ddim(shape),
                               phi::DataLayout::NCHW))),
      place_{place} {}
91

92
// Wraps an existing tensor implementation and assigns it a name.
// Unlike the single-argument overload, no null check is performed on
// `tensor_impl` here.
//
// `name` is a const reference, so it can only be copied; it is copied
// explicitly rather than through a std::move that could never move.
Tensor::Tensor(std::shared_ptr<phi::TensorBase> tensor_impl,
               const std::string &name)
    : impl_(std::move(tensor_impl)), name_(name) {}
95 96 97 98 99 100
/* Part 2: Dimension, DataType and DataLayout methods */

// Returns the element count reported by the underlying implementation.
int64_t Tensor::numel() const {
  return impl_->numel();
}

int64_t Tensor::size() const { return impl_->numel(); }

101
// Returns the dimensions reported by the underlying implementation.
phi::DDim Tensor::dims() const {
  return impl_->dims();
}
102 103

std::vector<int64_t> Tensor::shape() const {
104
  return phi::vectorize<int64_t>(impl_->dims());
105 106 107
}

void Tensor::reshape(const std::vector<int64_t> &shape) {
108 109 110 111 112 113 114 115
  LOG(WARNING) << "The function of resetting the shape of the uninitialized "
                  "Tensor of the `reshape` method is deprecated since version "
                  "2.3, and will be removed in version 2.4, please use "
                  "`paddle::experimental::full` method to create a new Tensor "
                  "instead. "
                  "reason: `reshape` means changing the tensor shape without "
                  "touching underlying data, this requires the total size of "
                  "the tensor to remain constant.";
C
Chen Weihang 已提交
116
  if (is_dense_tensor()) {
117 118
    std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->Resize(
        phi::make_ddim(shape));
119
  } else {
120
    PADDLE_THROW(phi::errors::Unimplemented(
121 122
        "Only support reshape operation on DenseTensor now."));
  }
123 124
}

125
// Returns the data type reported by the underlying implementation.
DataType Tensor::dtype() const {
  return impl_->dtype();
}
126

127
// Alias of dtype(): returns the data type of the implementation.
DataType Tensor::type() const {
  return dtype();
}
128 129 130

// Returns the data layout reported by the underlying implementation.
DataLayout Tensor::layout() const {
  return impl_->layout();
}

C
Chen Weihang 已提交
131
bool Tensor::is_dense_tensor() const {
132
  return phi::DenseTensor::classof(impl_.get());
C
Chen Weihang 已提交
133
}
134
bool Tensor::is_selected_rows() const {
135
  return phi::SelectedRows::classof(impl_.get());
136
}
137 138 139 140 141 142
// Returns whether phi::SparseCooTensor::classof accepts the held
// implementation.
bool Tensor::is_sparse_coo_tensor() const {
  auto *raw_impl = impl_.get();
  return phi::SparseCooTensor::classof(raw_impl);
}
// Returns whether phi::SparseCsrTensor::classof accepts the held
// implementation.
bool Tensor::is_sparse_csr_tensor() const {
  auto *raw_impl = impl_.get();
  return phi::SparseCsrTensor::classof(raw_impl);
}
143 144 145
/* Part 3: Device and Backend methods */

// Returns the external PlaceType of this tensor.
// Once the implementation is initialized, its own place is converted and
// returned; before that, the place_ recorded on this Tensor is used.
PlaceType Tensor::place() const {
  if (impl_->initialized()) {
    return ConvertInnerPlaceToExtPlace(impl_->place());
  }
  return place_;
}

153
// Returns the place reported by the underlying implementation.
// A null impl_ is reported through PADDLE_ENFORCE_NOT_NULL as a
// PermissionDenied error.
paddle::platform::Place Tensor::inner_place() const {
  PADDLE_ENFORCE_NOT_NULL(
      impl_,
      phi::errors::PermissionDenied(
          "Null pointer error, the impl_ of Tensor should not be "
          "Null when calling Tensor::inner_place()."));
  const auto &impl_place = impl_->place();
  return impl_place;
}
161 162

bool Tensor::is_cpu() const {
163
  return paddle::platform::is_cpu_place(inner_place());
164 165
}

166
bool Tensor::is_gpu() const {
167
  return paddle::platform::is_gpu_place(inner_place());
168 169
}

170 171 172 173
bool Tensor::is_gpu_pinned() const {
  return paddle::platform::is_cuda_pinned_place(inner_place());
}

174 175 176 177
/* Part 4: Data Access methods */

template <typename T>
T *Tensor::mutable_data() {
C
Chen Weihang 已提交
178
  if (is_dense_tensor()) {
179
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->mutable_data<T>(
180
        ConvertExtPlaceToInnerPlace(place()));
181 182 183 184
  }
  return nullptr;
}

185 186 187 188 189 190 191 192
// Explicit instantiations of Tensor::mutable_data<T>() exported via
// PADDLE_API, grouped by element-type family.

// Boolean and integral element types.
template PADDLE_API bool *Tensor::mutable_data<bool>();
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>();
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>();
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>();
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>();
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>();

// Floating-point element types.
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>();
template PADDLE_API float *Tensor::mutable_data<float>();
template PADDLE_API double *Tensor::mutable_data<double>();

// Complex element types.
template PADDLE_API phi::dtype::complex<float>
    *Tensor::mutable_data<phi::dtype::complex<float>>();
template PADDLE_API phi::dtype::complex<double>
    *Tensor::mutable_data<phi::dtype::complex<double>>();
199 200 201 202

template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
  auto inner_place = ConvertExtPlaceToInnerPlace(place);
203 204 205 206
  if (impl_->initialized()) {
    PADDLE_ENFORCE_EQ(
        platform::is_same_place(inner_place, impl_->place()),
        true,
207 208
        phi::errors::Unimplemented("Modification of tensor place through "
                                   "mutable_data is not supported now"));
209 210
  }
  if (is_dense_tensor()) {
211
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->mutable_data<T>(
212 213 214
        inner_place);
  }
  return nullptr;
215 216
}

217 218
template PADDLE_API float *Tensor::mutable_data<float>(const PlaceType &place);
template PADDLE_API double *Tensor::mutable_data<double>(
219
    const PlaceType &place);
220
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>(
221
    const PlaceType &place);
222
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>(
223
    const PlaceType &place);
224
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>(
225
    const PlaceType &place);
226
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>(
227
    const PlaceType &place);
228
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>(
229
    const PlaceType &place);
230
template PADDLE_API bool *Tensor::mutable_data<bool>(const PlaceType &place);
231 232 233 234 235 236
template PADDLE_API phi::dtype::complex<float>
    *Tensor::mutable_data<phi::dtype::complex<float>>(const PlaceType &place);
template PADDLE_API phi::dtype::complex<double>
    *Tensor::mutable_data<phi::dtype::complex<double>>(const PlaceType &place);
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>(const PlaceType &place);
237 238 239

template <typename T>
const T *Tensor::data() const {
C
Chen Weihang 已提交
240
  if (is_dense_tensor()) {
241 242 243
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->data<T>();
  } else if (phi::SelectedRows::classof(impl_.get())) {
    return std::dynamic_pointer_cast<phi::SelectedRows>(impl_)
244 245
        ->value()
        .data<T>();
246 247 248 249
  }
  return nullptr;
}

250 251 252 253 254 255 256 257
// Explicit instantiations of the const overload Tensor::data<T>() exported
// via PADDLE_API, grouped by element-type family.

// Boolean and integral element types.
template PADDLE_API const bool *Tensor::data<bool>() const;
template PADDLE_API const int8_t *Tensor::data<int8_t>() const;
template PADDLE_API const uint8_t *Tensor::data<uint8_t>() const;
template PADDLE_API const int16_t *Tensor::data<int16_t>() const;
template PADDLE_API const int32_t *Tensor::data<int32_t>() const;
template PADDLE_API const int64_t *Tensor::data<int64_t>() const;

// Floating-point element types.
template PADDLE_API const phi::dtype::float16 *
Tensor::data<phi::dtype::float16>() const;
template PADDLE_API const phi::dtype::bfloat16 *
Tensor::data<phi::dtype::bfloat16>() const;
template PADDLE_API const float *Tensor::data<float>() const;
template PADDLE_API const double *Tensor::data<double>() const;

// Complex element types.
template PADDLE_API const phi::dtype::complex<float>
    *Tensor::data<phi::dtype::complex<float>>() const;
template PADDLE_API const phi::dtype::complex<double>
    *Tensor::data<phi::dtype::complex<double>>() const;
266 267 268

template <typename T>
T *Tensor::data() {
269
  if (is_dense_tensor()) {
270 271 272
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->data<T>();
  } else if (phi::SelectedRows::classof(impl_.get())) {
    return std::dynamic_pointer_cast<phi::SelectedRows>(impl_)
273 274 275
        ->mutable_value()
        ->data<T>();
  }
276 277 278
  return nullptr;
}

279 280 281 282 283 284 285 286
// Explicit instantiations of the non-const overload Tensor::data<T>()
// exported via PADDLE_API, grouped by element-type family.
// NOTE(review): the const overload is additionally instantiated for
// phi::dtype::bfloat16 in this file, this one is not - confirm that the
// difference is intentional.

// Boolean and integral element types.
template PADDLE_API bool *Tensor::data<bool>();
template PADDLE_API int8_t *Tensor::data<int8_t>();
template PADDLE_API uint8_t *Tensor::data<uint8_t>();
template PADDLE_API int16_t *Tensor::data<int16_t>();
template PADDLE_API int32_t *Tensor::data<int32_t>();
template PADDLE_API int64_t *Tensor::data<int64_t>();

// Floating-point element types.
template PADDLE_API phi::dtype::float16 *Tensor::data<phi::dtype::float16>();
template PADDLE_API float *Tensor::data<float>();
template PADDLE_API double *Tensor::data<double>();

// Complex element types.
template PADDLE_API phi::dtype::complex<float>
    *Tensor::data<phi::dtype::complex<float>>();
template PADDLE_API phi::dtype::complex<double>
    *Tensor::data<phi::dtype::complex<double>>();
292

293
// TODO(chenweihang): replace slice impl by API
294
Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
C
Chen Weihang 已提交
295
  if (is_dense_tensor()) {
296 297 298
    return Tensor(std::make_shared<phi::DenseTensor>(
        std::move(phi::DenseTensorUtils::Slice(
            *(std::dynamic_pointer_cast<phi::DenseTensor>(impl_).get()),
299 300 301
            begin_idx,
            end_idx))));
  } else {
302
    PADDLE_THROW(phi::errors::Unimplemented(
303
        "Only support slice operation on DenseTensor now."));
304
  }
305 306
}

307
// Returns a const reference to the shared pointer holding the implementation.
const std::shared_ptr<phi::TensorBase> &Tensor::impl() const {
  return impl_;
}
308

309
// Replaces the held implementation with a copy of `impl`, so ownership is
// shared with the caller's pointer.
void Tensor::set_impl(const std::shared_ptr<phi::TensorBase> &impl) {
  this->impl_ = impl;
}

313 314 315 316
// Replaces the held implementation by moving from `impl`.
void Tensor::set_impl(std::shared_ptr<phi::TensorBase> &&impl) {
  this->impl_ = std::move(impl);
}

317 318 319 320 321 322
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Returns the raw stream of the stream object obtained from
// platform::stream::get_current_stream(-1). Only compiled for CUDA/HIP builds.
// NOTE(review): -1 presumably selects the current device - confirm against
// get_current_stream's documentation.
gpuStream_t Tensor::stream() const {
  auto current_stream = platform::stream::get_current_stream(-1);
  return current_stream->raw_stream();
}
#endif

323
/* Part 5: Status utils methods */
324 325 326

// Returns true when an implementation is held (impl_ is non-null).
bool Tensor::defined() const {
  return static_cast<bool>(impl_);
}

327
// Returns true when an implementation is held and it reports itself as
// initialized. The implementation is not queried when none is held.
bool Tensor::initialized() const {
  if (!defined()) {
    return false;
  }
  return impl_->initialized();
}
328 329

bool Tensor::is_initialized() const {
330
  return defined() && impl_->initialized();
331 332 333 334
}

// Drops this Tensor's reference to the implementation; defined() is false
// afterwards. Other members (name, place, autograd meta) are left untouched.
void Tensor::reset() {
  impl_.reset();
}

335
/* Part 6: Operator overloading */
336 337 338 339

// Copy assignment (lvalue-qualified).
// The implementation and autograd meta pointers are copied, so both tensors
// share the same underlying objects afterwards; name and place are copied
// by value.
Tensor &Tensor::operator=(const Tensor &x) & {
  // Shared handles: copying only adds an owner.
  this->impl_ = x.impl_;
  this->autograd_meta_ = x.autograd_meta_;

  // Plain value members.
  this->name_ = x.name_;
  this->place_ = x.place_;

  return *this;
}

// Move assignment (lvalue-qualified).
// Transfers the implementation, autograd meta, name and place from `x`,
// leaving `x` in a moved-from state.
//
// Self-assignment (`t = std::move(t)`) is a no-op: moving a member onto
// itself only guarantees a valid-but-unspecified state, which could e.g.
// leave name_ empty.
Tensor &Tensor::operator=(Tensor &&x) & {
  if (this != &x) {
    impl_ = std::move(x.impl_);
    autograd_meta_ = std::move(x.autograd_meta_);
    name_ = std::move(x.name_);
    place_ = std::move(x.place_);
  }
  return *this;
}

// Returns a non-owning pointer to the autograd meta held by this Tensor
// (whatever autograd_meta_.get() yields, including null).
AbstractAutogradMeta *Tensor::get_autograd_meta() const {
  AbstractAutogradMeta *raw_meta = autograd_meta_.get();
  return raw_meta;
}

void Tensor::set_autograd_meta(
    std::shared_ptr<AbstractAutogradMeta> autograd_meta) {
  autograd_meta_ = std::move(autograd_meta);
}

362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386
void Tensor::bump_inplace_version() {
  if (is_dense_tensor()) {
    auto &inplace_version_counter =
        std::dynamic_pointer_cast<phi::DenseTensor>(impl_)
            ->InplaceVersionCounter();
    inplace_version_counter.Bump();
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "bump_inplace_version is only supported on DenseTensor now."));
  }
}

uint32_t Tensor::current_inplace_version() {
  if (is_dense_tensor()) {
    auto &inplace_version_counter =
        std::dynamic_pointer_cast<phi::DenseTensor>(impl_)
            ->InplaceVersionCounter();
    return inplace_version_counter.CurrentVersion();
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "current_inplace_version is only supported on DenseTensor now."));
  }
  return 0;
}

387 388
}  // namespace experimental
}  // namespace paddle