tensor.cc 12.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/api/include/tensor.h"
16 17 18 19 20 21

#include <memory>
#include <utility>
#include <vector>

#include "glog/logging.h"
22 23 24 25 26 27 28 29 30
#include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
31 32 33 34 35
/**
 * [ Why still include the fluid headers? ]
 *
 * We hope to organize the basic implementation of Tensor and the logic related
 * to Tensor computation into an independent library, which we call
 * [Tensor Operation Library, phi], so we extract or rewrite the original
 * Kernels.
 *
 * In the future, the training library, inference library and custom operators
 * will link to this Tensor Operation library.
 *
 * However, splitting the link relationship directly would require too many
 * changes and would affect the stability of the framework, so for now we
 * still rely on the framework's implementation; this is an intermediate state.
 *
 * In the future, the necessary components will be moved into this library,
 * or the corresponding components will be re-implemented.
 */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
52 53 54 55
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
56 57 58 59 60 61 62 63

namespace paddle {
namespace experimental {

/////// Tensor Methods ////////

/* Part 1: Construction and destruction methods */

64
Tensor::Tensor(std::shared_ptr<phi::TensorBase> tensor_impl)
65
    : impl_(std::move(tensor_impl)) {
66 67 68
  PADDLE_ENFORCE_NOT_NULL(
      impl_,
      phi::errors::InvalidArgument("TensorImpl with nullptr is not supported"));
69 70 71
}

Tensor::Tensor(const PlaceType &place)
72 73
    : impl_(std::move(std::make_shared<phi::DenseTensor>(
          std::move(phi::make_intrusive<SharedStorage>(
74
              ConvertExtPlaceToInnerPlace(place))),
75 76 77
          std::move(phi::DenseTensorMeta(phi::DataType::UNDEFINED,
                                         phi::make_ddim({}),
                                         phi::DataLayout::NCHW))))),
78
      place_{place} {}
79 80

Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
81 82
    : impl_(std::move(std::make_shared<phi::DenseTensor>(
          std::move(phi::make_intrusive<SharedStorage>(
83
              ConvertExtPlaceToInnerPlace(place))),
84 85 86
          std::move(phi::DenseTensorMeta(phi::DataType::UNDEFINED,
                                         phi::make_ddim(shape),
                                         phi::DataLayout::NCHW))))),
87
      place_{place} {}
88

89
// Wraps `tensor_impl` and records `name`.
// `name` is a const reference, so it is copied; the previous std::move here
// was a no-op that only suggested a move which never happened.
// NOTE: unlike the single-argument constructor, this overload does not reject
// a null `tensor_impl`.
Tensor::Tensor(std::shared_ptr<phi::TensorBase> tensor_impl,
               const std::string &name)
    : impl_(std::move(tensor_impl)), name_(name) {}
/* Part 2: Dimension, DataType and DataLayout methods */

int64_t Tensor::numel() const { return impl_->numel(); }

int64_t Tensor::size() const { return impl_->numel(); }

98
phi::DDim Tensor::dims() const { return impl_->dims(); }
99 100

std::vector<int64_t> Tensor::shape() const {
101
  return phi::vectorize<int64_t>(impl_->dims());
102 103 104
}

void Tensor::reshape(const std::vector<int64_t> &shape) {
105 106 107 108 109 110 111 112
  LOG(WARNING) << "The function of resetting the shape of the uninitialized "
                  "Tensor of the `reshape` method is deprecated since version "
                  "2.3, and will be removed in version 2.4, please use "
                  "`paddle::experimental::full` method to create a new Tensor "
                  "instead. "
                  "reason: `reshape` means changing the tensor shape without "
                  "touching underlying data, this requires the total size of "
                  "the tensor to remain constant.";
C
Chen Weihang 已提交
113
  if (is_dense_tensor()) {
114 115
    std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->Resize(
        phi::make_ddim(shape));
116
  } else {
117
    PADDLE_THROW(phi::errors::Unimplemented(
118 119
        "Only support reshape operation on DenseTensor now."));
  }
120 121
}

122
DataType Tensor::dtype() const { return impl_->dtype(); }
123

124
DataType Tensor::type() const { return impl_->dtype(); }
125 126 127

DataLayout Tensor::layout() const { return impl_->layout(); }

C
Chen Weihang 已提交
128
bool Tensor::is_dense_tensor() const {
129
  return phi::DenseTensor::classof(impl_.get());
C
Chen Weihang 已提交
130
}
131
bool Tensor::is_selected_rows() const {
132
  return phi::SelectedRows::classof(impl_.get());
133
}
134 135 136
/* Part 3: Device and Backend methods */

// Once the implementation reports initialized(), its actual place is converted
// back to a PlaceType. Before that, the place recorded at construction
// (`place_`) is returned.
PlaceType Tensor::place() const {
  return impl_->initialized() ? ConvertInnerPlaceToExtPlace(impl_->place())
                              : place_;
}

144 145 146
// The result of place(), converted to the framework-internal Place type.
paddle::platform::Place Tensor::inner_place() const {
  const PlaceType ext_place = place();
  return ConvertExtPlaceToInnerPlace(ext_place);
}

bool Tensor::is_cpu() const {
149
  return paddle::platform::is_cpu_place(inner_place());
150 151 152
}

bool Tensor::is_cuda() const {
153
  return paddle::platform::is_gpu_place(inner_place());
154 155 156 157 158 159
}

/* Part 4: Data Access methods */

template <typename T>
T *Tensor::mutable_data() {
C
Chen Weihang 已提交
160
  if (is_dense_tensor()) {
161
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->mutable_data<T>(
162
        ConvertExtPlaceToInnerPlace(place()));
163 164 165 166
  }
  return nullptr;
}

167 168 169 170 171 172 173 174
// Explicit instantiations of Tensor::mutable_data<T>() exported via
// PADDLE_API, listed roughly by element width.
template PADDLE_API bool *Tensor::mutable_data<bool>();
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>();
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>();
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>();
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>();
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>();
template PADDLE_API float *Tensor::mutable_data<float>();
template PADDLE_API double *Tensor::mutable_data<double>();
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>();
template PADDLE_API phi::dtype::complex<float>
    *Tensor::mutable_data<phi::dtype::complex<float>>();
template PADDLE_API phi::dtype::complex<double>
    *Tensor::mutable_data<phi::dtype::complex<double>>();

template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
  auto inner_place = ConvertExtPlaceToInnerPlace(place);
185 186 187 188
  if (impl_->initialized()) {
    PADDLE_ENFORCE_EQ(
        platform::is_same_place(inner_place, impl_->place()),
        true,
189 190
        phi::errors::Unimplemented("Modification of tensor place through "
                                   "mutable_data is not supported now"));
191 192
  }
  if (is_dense_tensor()) {
193
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->mutable_data<T>(
194 195 196
        inner_place);
  }
  return nullptr;
197 198
}

199 200
template PADDLE_API float *Tensor::mutable_data<float>(const PlaceType &place);
template PADDLE_API double *Tensor::mutable_data<double>(
201
    const PlaceType &place);
202
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>(
203
    const PlaceType &place);
204
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>(
205
    const PlaceType &place);
206
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>(
207
    const PlaceType &place);
208
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>(
209
    const PlaceType &place);
210
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>(
211
    const PlaceType &place);
212
template PADDLE_API bool *Tensor::mutable_data<bool>(const PlaceType &place);
213 214 215 216 217 218
template PADDLE_API phi::dtype::complex<float>
    *Tensor::mutable_data<phi::dtype::complex<float>>(const PlaceType &place);
template PADDLE_API phi::dtype::complex<double>
    *Tensor::mutable_data<phi::dtype::complex<double>>(const PlaceType &place);
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>(const PlaceType &place);
219 220 221

template <typename T>
const T *Tensor::data() const {
C
Chen Weihang 已提交
222
  if (is_dense_tensor()) {
223 224 225
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->data<T>();
  } else if (phi::SelectedRows::classof(impl_.get())) {
    return std::dynamic_pointer_cast<phi::SelectedRows>(impl_)
226 227
        ->value()
        .data<T>();
228 229 230 231
  }
  return nullptr;
}

232 233 234 235 236 237 238 239
// Explicit instantiations of the const Tensor::data<T>() exported via
// PADDLE_API, listed roughly by element width.
template PADDLE_API const bool *Tensor::data<bool>() const;
template PADDLE_API const int8_t *Tensor::data<int8_t>() const;
template PADDLE_API const uint8_t *Tensor::data<uint8_t>() const;
template PADDLE_API const int16_t *Tensor::data<int16_t>() const;
template PADDLE_API const int32_t *Tensor::data<int32_t>() const;
template PADDLE_API const int64_t *Tensor::data<int64_t>() const;
template PADDLE_API const float *Tensor::data<float>() const;
template PADDLE_API const double *Tensor::data<double>() const;
template PADDLE_API const phi::dtype::float16 *
Tensor::data<phi::dtype::float16>() const;
template PADDLE_API const phi::dtype::bfloat16 *
Tensor::data<phi::dtype::bfloat16>() const;
template PADDLE_API const phi::dtype::complex<float>
    *Tensor::data<phi::dtype::complex<float>>() const;
template PADDLE_API const phi::dtype::complex<double>
    *Tensor::data<phi::dtype::complex<double>>() const;

template <typename T>
T *Tensor::data() {
251
  if (is_dense_tensor()) {
252 253 254
    return std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->data<T>();
  } else if (phi::SelectedRows::classof(impl_.get())) {
    return std::dynamic_pointer_cast<phi::SelectedRows>(impl_)
255 256 257
        ->mutable_value()
        ->data<T>();
  }
258 259 260
  return nullptr;
}

261 262 263 264 265 266 267 268
// Explicit instantiations of the non-const Tensor::data<T>() exported via
// PADDLE_API, listed roughly by element width.
// NOTE(review): unlike the const overload, no bfloat16 instantiation is
// exported here. Confirm whether that is intentional.
template PADDLE_API bool *Tensor::data<bool>();
template PADDLE_API int8_t *Tensor::data<int8_t>();
template PADDLE_API uint8_t *Tensor::data<uint8_t>();
template PADDLE_API int16_t *Tensor::data<int16_t>();
template PADDLE_API int32_t *Tensor::data<int32_t>();
template PADDLE_API int64_t *Tensor::data<int64_t>();
template PADDLE_API float *Tensor::data<float>();
template PADDLE_API double *Tensor::data<double>();
template PADDLE_API phi::dtype::float16 *Tensor::data<phi::dtype::float16>();
template PADDLE_API phi::dtype::complex<float>
    *Tensor::data<phi::dtype::complex<float>>();
template PADDLE_API phi::dtype::complex<double>
    *Tensor::data<phi::dtype::complex<double>>();

275
// TODO(chenweihang): replace slice impl by API
276
Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
C
Chen Weihang 已提交
277
  if (is_dense_tensor()) {
278 279 280
    return Tensor(std::make_shared<phi::DenseTensor>(
        std::move(phi::DenseTensorUtils::Slice(
            *(std::dynamic_pointer_cast<phi::DenseTensor>(impl_).get()),
281 282 283
            begin_idx,
            end_idx))));
  } else {
284
    PADDLE_THROW(phi::errors::Unimplemented(
285
        "Only support slice operation on DenseTensor now."));
286
  }
287 288
}

289
std::shared_ptr<phi::TensorBase> Tensor::impl() const { return impl_; }
290

291
void Tensor::set_impl(const std::shared_ptr<phi::TensorBase> &impl) {
292 293 294 295 296 297 298 299 300
  impl_ = impl;
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Raw GPU stream obtained from platform::stream::get_current_stream.
// NOTE(review): -1 is passed as the device id; presumably this selects the
// current device. Confirm against get_current_stream.
gpuStream_t Tensor::stream() const {
  auto current_stream = platform::stream::get_current_stream(-1);
  return current_stream->raw_stream();
}
#endif

301
/* Part 5: Status utils methods */
302 303 304

bool Tensor::defined() const { return impl_ != nullptr; }

305
bool Tensor::initialized() const { return defined() && impl_->initialized(); }
306 307

bool Tensor::is_initialized() const {
308
  return defined() && impl_->initialized();
309 310 311 312
}

void Tensor::reset() { impl_.reset(); }

313
/* Part 6: Operator overloading */
314 315 316 317

// Copy assignment (lvalue `this` only).
// The implementation and autograd metadata are shared with `x`, not
// deep-copied. The name and the recorded place are copied.
Tensor &Tensor::operator=(const Tensor &x) & {
  impl_ = x.impl_;
  autograd_meta_ = x.autograd_meta_;
  name_ = x.name_;
  place_ = x.place_;
  return *this;
}

// Move assignment (lvalue `this` only). Every member is taken from `x`.
Tensor &Tensor::operator=(Tensor &&x) & {
  impl_ = std::move(x.impl_);
  autograd_meta_ = std::move(x.autograd_meta_);
  name_ = std::move(x.name_);
  place_ = std::move(x.place_);
  return *this;
}

// Raw, non-owning pointer to the autograd metadata; ownership stays with
// autograd_meta_.
AbstractAutogradMeta *Tensor::get_autograd_meta() const {
  return autograd_meta_.get();
}

// Takes ownership of `autograd_meta`, replacing any previous metadata.
void Tensor::set_autograd_meta(
    std::shared_ptr<AbstractAutogradMeta> autograd_meta) {
  autograd_meta_ = std::move(autograd_meta);
}

}  // namespace experimental
}  // namespace paddle