/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/include/tensor.h"

#include <memory>
#include <utility>
#include <vector>

#include "glog/logging.h"
#include "paddle/phi/api/include/manual_api.h"
#include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
/**
 * [ Why still include the fluid headers? ]
 *
 * We hope to organize the basic implementation of Tensor and the logic related
 * to Tensor computation into an independent library, which we call
 * [Tensor Operation Library, pten], so we extract or rewrite the original
 * Kernels.
 *
 * In the future, the training library, inference library and custom operators
 * will link to this Tensor Operation library.
 *
 * However, if we split the link relation directly, we would need to make too
 * many changes, which would affect the stability of the framework, so here we
 * still rely on the implementation of the framework; this is an intermediate
 * state.
 *
 * In the future, the necessary components will be moved into this library,
 * or the corresponding components will be re-implemented.
 */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace experimental {

/////// Tensor Methods ////////

/* Part 1: Construction and destruction methods */

// Wraps an existing TensorBase implementation.
// Throws InvalidArgument if `tensor_impl` is null; a Tensor built through this
// constructor therefore always has an implementation.
Tensor::Tensor(std::shared_ptr<phi::TensorBase> tensor_impl) {
  PADDLE_ENFORCE_NOT_NULL(
      tensor_impl,
      phi::errors::InvalidArgument("TensorImpl with nullptr is not supported"));
  impl_ = std::move(tensor_impl);
}

// Creates a DenseTensor whose storage is bound to `place`.
// The meta has an UNDEFINED dtype, an empty dims list and the NCHW layout.
// `place_` keeps the requested place so that place() can report it while
// the implementation is still uninitialized.
Tensor::Tensor(const PlaceType &place)
    : impl_(std::make_shared<phi::DenseTensor>(
          phi::make_intrusive<SharedStorage>(
              ConvertExtPlaceToInnerPlace(place)),
          phi::DenseTensorMeta(phi::DataType::UNDEFINED,
                               phi::make_ddim({}),
                               phi::DataLayout::NCHW))),
      place_{place} {}

// Creates a DenseTensor whose storage is bound to `place` and whose dims are
// taken from `shape`.
// The dtype stays UNDEFINED and the layout is NCHW. `place_` records the
// requested place for use while the implementation is uninitialized.
Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
    : impl_(std::make_shared<phi::DenseTensor>(
          phi::make_intrusive<SharedStorage>(
              ConvertExtPlaceToInnerPlace(place)),
          phi::DenseTensorMeta(phi::DataType::UNDEFINED,
                               phi::make_ddim(shape),
                               phi::DataLayout::NCHW))),
      place_{place} {}

// Wraps an existing TensorBase implementation and names the tensor.
// Unlike the single-argument constructor, no null check is performed here.
// `name` arrives as a const reference, so it is copied; the previous
// `std::move(name)` could never move from a const object and only hid that
// copy.
Tensor::Tensor(std::shared_ptr<phi::TensorBase> tensor_impl,
               const std::string &name)
    : impl_(std::move(tensor_impl)), name_(name) {}
/* Part 2: Dimension, DataType and DataLayout methods */

// Number of elements, as reported by the underlying implementation.
int64_t Tensor::numel() const {
  return impl_->numel();
}

// Same value as numel(): the implementation's element count.
int64_t Tensor::size() const {
  return impl_->numel();
}
// Dimensions of the underlying implementation, returned by value.
phi::DDim Tensor::dims() const {
  return impl_->dims();
}

std::vector<int64_t> Tensor::shape() const {
102
  return phi::vectorize<int64_t>(impl_->dims());
103 104 105
}

void Tensor::reshape(const std::vector<int64_t> &shape) {
106 107 108 109 110 111 112 113
  LOG(WARNING) << "The function of resetting the shape of the uninitialized "
                  "Tensor of the `reshape` method is deprecated since version "
                  "2.3, and will be removed in version 2.4, please use "
                  "`paddle::experimental::full` method to create a new Tensor "
                  "instead. "
                  "reason: `reshape` means changing the tensor shape without "
                  "touching underlying data, this requires the total size of "
                  "the tensor to remain constant.";
C
Chen Weihang 已提交
114
  if (is_dense_tensor()) {
115 116
    std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->set_meta(
        phi::DenseTensorMeta(dtype(), phi::make_ddim(shape)));
117
  } else {
118
    PADDLE_THROW(phi::errors::Unimplemented(
119 120
        "Only support reshape operation on DenseTensor now."));
  }
121 122
}

123
// Element data type of the underlying implementation.
DataType Tensor::dtype() const {
  return impl_->dtype();
}

// Alias of dtype(): the implementation's element data type.
DataType Tensor::type() const {
  return impl_->dtype();
}

// Data layout of the underlying implementation.
DataLayout Tensor::layout() const {
  return impl_->layout();
}

bool Tensor::is_dense_tensor() const {
130
  return phi::DenseTensor::classof(impl_.get());
C
Chen Weihang 已提交
131
}
132
bool Tensor::is_selected_rows() const {
133
  return phi::SelectedRows::classof(impl_.get());
134
}
135 136 137
/* Part 3: Device and Backend methods */

// External place of the tensor.
// If the implementation is initialized, its place is converted to a
// PlaceType. Otherwise the place recorded at construction (`place_`) is
// returned.
PlaceType Tensor::place() const {
  return impl_->initialized() ? ConvertInnerPlaceToExtPlace(impl_->place())
                              : place_;
}

// Internal (platform) place of the tensor.
// For an initialized implementation this is the allocation's own place. It
// is returned directly instead of being round-tripped through PlaceType:
// PlaceType only names a device kind, so converting back with
// ConvertExtPlaceToInnerPlace may pick a different device index than the one
// the data actually lives on (NOTE(review): presumably the current device).
// For an uninitialized implementation, the place recorded at construction is
// converted, exactly as before.
paddle::platform::Place Tensor::inner_place() const {
  if (impl_->initialized()) {
    return impl_->place();
  }
  return ConvertExtPlaceToInnerPlace(place_);
}

bool Tensor::is_cpu() const {
150
  return paddle::platform::is_cpu_place(inner_place());
151 152 153
}

bool Tensor::is_cuda() const {
154
  return paddle::platform::is_gpu_place(inner_place());
155 156 157 158 159 160
}

/* Part 4: Data Access methods */

template <typename T>
T *Tensor::mutable_data() {
C
Chen Weihang 已提交
161
  if (is_dense_tensor()) {
162
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->mutable_data<T>(
163
        ConvertExtPlaceToInnerPlace(place()));
164 165 166 167
  }
  return nullptr;
}

168 169 170 171 172 173 174 175
template PADDLE_API float *Tensor::mutable_data<float>();
template PADDLE_API double *Tensor::mutable_data<double>();
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>();
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>();
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>();
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>();
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>();
template PADDLE_API bool *Tensor::mutable_data<bool>();
176 177 178 179 180 181
template PADDLE_API phi::dtype::complex<float>
    *Tensor::mutable_data<phi::dtype::complex<float>>();
template PADDLE_API phi::dtype::complex<double>
    *Tensor::mutable_data<phi::dtype::complex<double>>();
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>();
182 183 184 185

template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
  auto inner_place = ConvertExtPlaceToInnerPlace(place);
186 187 188 189
  if (impl_->initialized()) {
    PADDLE_ENFORCE_EQ(
        platform::is_same_place(inner_place, impl_->place()),
        true,
190 191
        phi::errors::Unimplemented("Modification of tensor place through "
                                   "mutable_data is not supported now"));
192 193
  }
  if (is_dense_tensor()) {
194
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->mutable_data<T>(
195 196 197
        inner_place);
  }
  return nullptr;
198 199
}

200 201
template PADDLE_API float *Tensor::mutable_data<float>(const PlaceType &place);
template PADDLE_API double *Tensor::mutable_data<double>(
202
    const PlaceType &place);
203
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>(
204
    const PlaceType &place);
205
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>(
206
    const PlaceType &place);
207
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>(
208
    const PlaceType &place);
209
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>(
210
    const PlaceType &place);
211
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>(
212
    const PlaceType &place);
213
template PADDLE_API bool *Tensor::mutable_data<bool>(const PlaceType &place);
214 215 216 217 218 219
template PADDLE_API phi::dtype::complex<float>
    *Tensor::mutable_data<phi::dtype::complex<float>>(const PlaceType &place);
template PADDLE_API phi::dtype::complex<double>
    *Tensor::mutable_data<phi::dtype::complex<double>>(const PlaceType &place);
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>(const PlaceType &place);
220 221 222

template <typename T>
const T *Tensor::data() const {
C
Chen Weihang 已提交
223
  if (is_dense_tensor()) {
224 225 226
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->data<T>();
  } else if (phi::SelectedRows::classof(impl_.get())) {
    return std::dynamic_pointer_cast<phi::SelectedRows>(impl_)
227 228
        ->value()
        .data<T>();
229 230 231 232
  }
  return nullptr;
}

233 234 235 236 237 238 239 240
template PADDLE_API const float *Tensor::data<float>() const;
template PADDLE_API const double *Tensor::data<double>() const;
template PADDLE_API const int64_t *Tensor::data<int64_t>() const;
template PADDLE_API const int32_t *Tensor::data<int32_t>() const;
template PADDLE_API const uint8_t *Tensor::data<uint8_t>() const;
template PADDLE_API const int8_t *Tensor::data<int8_t>() const;
template PADDLE_API const int16_t *Tensor::data<int16_t>() const;
template PADDLE_API const bool *Tensor::data<bool>() const;
241 242 243 244 245 246 247 248
template PADDLE_API const phi::dtype::complex<float>
    *Tensor::data<phi::dtype::complex<float>>() const;
template PADDLE_API const phi::dtype::complex<double>
    *Tensor::data<phi::dtype::complex<double>>() const;
template PADDLE_API const phi::dtype::float16 *
Tensor::data<phi::dtype::float16>() const;
template PADDLE_API const phi::dtype::bfloat16 *
Tensor::data<phi::dtype::bfloat16>() const;
249 250 251

template <typename T>
T *Tensor::data() {
252
  if (is_dense_tensor()) {
253 254 255
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->data<T>();
  } else if (phi::SelectedRows::classof(impl_.get())) {
    return std::dynamic_pointer_cast<phi::SelectedRows>(impl_)
256 257 258
        ->mutable_value()
        ->data<T>();
  }
259 260 261
  return nullptr;
}

262 263 264 265 266 267 268 269
template PADDLE_API float *Tensor::data<float>();
template PADDLE_API double *Tensor::data<double>();
template PADDLE_API int64_t *Tensor::data<int64_t>();
template PADDLE_API int32_t *Tensor::data<int32_t>();
template PADDLE_API uint8_t *Tensor::data<uint8_t>();
template PADDLE_API int8_t *Tensor::data<int8_t>();
template PADDLE_API int16_t *Tensor::data<int16_t>();
template PADDLE_API bool *Tensor::data<bool>();
270 271 272 273 274
template PADDLE_API phi::dtype::complex<float>
    *Tensor::data<phi::dtype::complex<float>>();
template PADDLE_API phi::dtype::complex<double>
    *Tensor::data<phi::dtype::complex<double>>();
template PADDLE_API phi::dtype::float16 *Tensor::data<phi::dtype::float16>();
275

276
// TODO(chenweihang): replace slice impl by API
277
Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
C
Chen Weihang 已提交
278
  if (is_dense_tensor()) {
279 280 281
    return Tensor(std::make_shared<phi::DenseTensor>(
        std::move(phi::DenseTensorUtils::Slice(
            *(std::dynamic_pointer_cast<phi::DenseTensor>(impl_).get()),
282 283 284
            begin_idx,
            end_idx))));
  } else {
285
    PADDLE_THROW(phi::errors::Unimplemented(
286
        "Only support slice operation on DenseTensor now."));
287
  }
288 289
}

290
// Returns a shared handle to the underlying implementation (may be null).
std::shared_ptr<phi::TensorBase> Tensor::impl() const {
  return impl_;
}

// Replaces the underlying implementation. Other members (name, place_,
// autograd meta) are left as they are.
void Tensor::set_impl(const std::shared_ptr<phi::TensorBase> &impl) {
  impl_ = impl;
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Raw GPU stream obtained from platform::stream::get_current_stream(-1).
// The stream is looked up independently of this tensor's own place.
gpuStream_t Tensor::stream() const {
  const auto current_stream = platform::stream::get_current_stream(-1);
  return current_stream->raw_stream();
}
#endif

/* Part 5: Data Transform methods */

template <typename T>
Tensor Tensor::copy_to(const PlaceType &target_place) const {
306
  LOG(WARNING) << "The Tensor's `copy_to` method is deprecated since version "
307
                  "2.3, and will be removed in version 2.4, please use "
308
                  "`copy_to` method without template argument instead. "
309 310
                  "reason: copying a Tensor to another device does not need "
                  "to specify the data type template argument.";
311
  return copy_to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false);
312 313
}

314
template PADDLE_API Tensor
315
Tensor::copy_to<float>(const PlaceType &target_place) const;
316
template PADDLE_API Tensor
317
Tensor::copy_to<double>(const PlaceType &target_place) const;
318
template PADDLE_API Tensor
319
Tensor::copy_to<int64_t>(const PlaceType &target_place) const;
320
template PADDLE_API Tensor
321
Tensor::copy_to<int32_t>(const PlaceType &target_place) const;
322
template PADDLE_API Tensor
323
Tensor::copy_to<uint8_t>(const PlaceType &target_place) const;
324
template PADDLE_API Tensor
325
Tensor::copy_to<int8_t>(const PlaceType &target_place) const;
326
template PADDLE_API Tensor
327
Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
328
template PADDLE_API Tensor
329
Tensor::copy_to<bool>(const PlaceType &target_place) const;
330
template PADDLE_API Tensor Tensor::copy_to<phi::dtype::complex<float>>(
331
    const PlaceType &target_place) const;
332
template PADDLE_API Tensor Tensor::copy_to<phi::dtype::complex<double>>(
333
    const PlaceType &target_place) const;
334
template PADDLE_API Tensor
335
Tensor::copy_to<phi::dtype::float16>(const PlaceType &target_place) const;
336

337 338
// Copies this tensor to `backend` by delegating to the free function
// experimental::copy_to. `blocking` is forwarded unchanged.
Tensor Tensor::copy_to(Backend backend, bool blocking) const {
  const Tensor &source = *this;
  return experimental::copy_to(source, backend, blocking);
}

void Tensor::copy_(const Tensor &src, bool blocking) {
  if (!src.is_initialized()) {
    return;
  }
  VLOG(3) << "Deep copy Tensor from " << src.name() << " to " << name();
  if (defined()) {
    PADDLE_ENFORCE_EQ(dtype(),
                      src.dtype(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s, "
                          "Tensor Copy cannot be performed!",
                          name(),
                          src.name()));
    PADDLE_ENFORCE_EQ(impl()->type_info().id(),
                      src.impl()->type_info().id(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "Copy cannot be performed!",
                          name(),
                          src.name()));
  }
  auto copy_tensor =
363
      src.copy_to(phi::TransToPtenBackend(src.inner_place()), blocking);
364 365
  set_impl(copy_tensor.impl());
}
366 367 368 369 370

/* Part 6: Status utils methods */

// True when the tensor holds an implementation (impl_ is non-null).
bool Tensor::defined() const {
  return static_cast<bool>(impl_);
}
// True when an implementation exists and reports itself initialized.
bool Tensor::initialized() const {
  if (!defined()) {
    return false;
  }
  return impl_->initialized();
}

bool Tensor::is_initialized() const {
374
  return defined() && impl_->initialized();
375 376 377 378 379 380 381 382 383
}

// Drops this tensor's reference to its implementation.
void Tensor::reset() {
  impl_.reset();
}

/* Part 7: Operator overloading */

// Copy assignment (lvalue-only). Shares x's implementation and autograd meta
// (shared_ptr copies) and copies its name and recorded place.
Tensor &Tensor::operator=(const Tensor &x) & {
  name_ = x.name_;
  place_ = x.place_;
  impl_ = x.impl_;
  autograd_meta_ = x.autograd_meta_;
  return *this;
}

// Move assignment (lvalue-only). Takes over x's implementation, autograd
// meta, name and recorded place.
Tensor &Tensor::operator=(Tensor &&x) & {
  name_ = std::move(x.name_);
  place_ = std::move(x.place_);
  impl_ = std::move(x.impl_);
  autograd_meta_ = std::move(x.autograd_meta_);
  return *this;
}

// Non-owning pointer to the autograd meta; null if none has been set.
AbstractAutogradMeta *Tensor::get_autograd_meta() const {
  AbstractAutogradMeta *meta = autograd_meta_.get();
  return meta;
}

// Replaces the autograd meta, taking shared ownership of `autograd_meta`.
void Tensor::set_autograd_meta(
    std::shared_ptr<AbstractAutogradMeta> autograd_meta) {
  autograd_meta_.swap(autograd_meta);
}

}  // namespace experimental
}  // namespace paddle