ext_tensor.cc 16.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/fluid/extension/include/ext_tensor.h"
16

17
#include <utility>
18

19 20 21
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
22 23
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
24
#include "paddle/fluid/platform/enforce.h"
25
#include "paddle/fluid/platform/float16.h"
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
#include "paddle/fluid/platform/transform.h"

namespace paddle {

/// Element-wise conversion functor handed to platform::Transform: turns one
/// InType value into an OutType value via static_cast. It is marked
/// HOSTDEVICE (presumably `__host__ __device__` under NVCC) so the same
/// functor can be used by both the CPU and CUDA transform paths.
template <typename InType, typename OutType>
struct CastDataTypeFunctor {
  HOSTDEVICE inline OutType operator()(InType value) const {
    const OutType converted = static_cast<OutType>(value);
    return converted;
  }
};

template <typename InType>
struct CastDataType {
  CastDataType(const framework::Tensor &in, framework::Tensor *out,
               const platform::DeviceContext *ctx)
      : in_(in), out_(out), ctx_(ctx) {}
  const framework::Tensor in_;
  framework::Tensor *out_;
  const platform::DeviceContext *ctx_;

  template <typename OutType>
  void apply() {
    auto *in_begin = in_.data<InType>();
    auto *in_end = in_begin + in_.numel();
    auto *out_begin = out_->mutable_data<OutType>(in_.place());

    if (platform::is_cpu_place(in_.place())) {
      platform::Transform<platform::CPUDeviceContext> trans;
      auto *context = static_cast<const platform::CPUDeviceContext *>(ctx_);
      trans(*context, in_begin, in_end, out_begin,
            CastDataTypeFunctor<InType, OutType>());
#ifdef __NVCC__
    } else if (platform::is_gpu_place(in_.place())) {
      platform::Transform<platform::CUDADeviceContext> trans;
      auto *context = static_cast<const platform::CUDADeviceContext *>(ctx_);
      trans(*context, in_begin, in_end, out_begin,
            CastDataTypeFunctor<InType, OutType>());
      context->Wait();
#endif
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Place type is not supported when casting data type."));
    }
  }
};
template <typename T>
void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
             int64_t ele_size) {
#ifdef PADDLE_WITH_CUDA
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  int device_num = paddle::platform::GetCurrentDeviceId();
  platform::CUDAPlace gpu_place(device_num);
  auto *dev_ctx =
      static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
  if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kCPU)) {
    memory::Copy(platform::CPUPlace(), static_cast<void *>(dst), gpu_place, src,
                 ele_size, dev_ctx->stream());
  } else if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kGPU)) {
    memory::Copy(gpu_place, static_cast<void *>(dst), gpu_place, src, ele_size,
                 dev_ctx->stream());
  } else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kGPU)) {
    memory::Copy(gpu_place, static_cast<void *>(dst), platform::CPUPlace(), src,
                 ele_size, dev_ctx->stream());
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Only GPU related Copy can reach this func."));
  }
  cudaStreamSynchronize(dev_ctx->stream());
#endif
}

// GET_CASTED_TENSOR declares a local `tensor` of type framework::LoDTensor*
// that points at the object owned by `tensor_`. If `tensor_` is empty, it is
// first replaced by a freshly created LoDTensor, so `tensor` is never null
// after the expansion.
// NOTE(review): this macro is expanded inside const member functions and still
// assigns `tensor_`, so `tensor_` is presumably declared `mutable` in
// ext_tensor.h — confirm.
// The expansion already ends with a semicolon; call sites that write
// `GET_CASTED_TENSOR;` just add a harmless empty statement.
#define GET_CASTED_TENSOR                               \
  if (!tensor_) {                                       \
    tensor_ = std::make_shared<framework::LoDTensor>(); \
  }                                                     \
  auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());

C
Chen Weihang 已提交
103
void Tensor::reshape(const std::vector<int64_t> &shape) {
104
  GET_CASTED_TENSOR
105 106
  auto new_dim = framework::make_ddim(shape);
  tensor->Resize(new_dim);
107 108 109
}

Tensor::Tensor(const PlaceType &place)
110 111 112
    : tensor_(std::make_shared<framework::LoDTensor>()),
      place_(place),
      stream_(StreamWrapper()) {}
113 114 115 116 117 118 119 120 121

Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
    : tensor_(std::make_shared<framework::LoDTensor>()),
      place_(place),
      stream_(StreamWrapper()) {
  GET_CASTED_TENSOR
  tensor->Resize(framework::make_ddim(shape));
}

122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
  place_ = place;
  return mutable_data<T>();
}

/// Returns a writable pointer to this tensor's data, allocated as T on the
/// place currently stored in `place_`.
///
/// The dims must already be set (via reshape() or the shape constructor);
/// a tensor with no elements is rejected.
/// @throws PreconditionNotMet if the tensor has zero (or negative) numel.
/// @throws Unavailable if `place_` is not supported; PlaceType::kGPU is only
///         handled when built with PADDLE_WITH_CUDA.
template <typename T>
T *Tensor::mutable_data() {
  GET_CASTED_TENSOR
  // The previous message named a non-existent Reshape(std::vector<int>) and
  // was missing a space between the concatenated literals.
  PADDLE_ENFORCE_GT(
      tensor->numel(), 0,
      platform::errors::PreconditionNotMet(
          "You should call Tensor::reshape(const std::vector<int64_t> "
          "&shape) function before retrieving mutable_data from input "
          "tensor."));
  switch (place_) {
    case PlaceType::kCPU: {
      return tensor->mutable_data<T>(platform::CPUPlace());
    }
#ifdef PADDLE_WITH_CUDA
    case PlaceType::kGPU: {
      // Allocates on the device reported by GetCurrentDeviceId().
      int device_num = platform::GetCurrentDeviceId();
      return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
    }
#endif
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Custom operator unsupported place id(%d)",
          static_cast<int>(place_)));
  }
}

/// Returns a pointer to the tensor's elements viewed as T, as returned by
/// LoDTensor::data<T>().
/// NOTE(review): the returned pointer is non-const even though this member
/// function is const.
template <typename T>
T *Tensor::data() const {
  GET_CASTED_TENSOR
  return tensor->data<T>();
}

/// Translates the inner framework::proto::VarType of the held tensor into the
/// public DataType enum. Inner types without a case below are reported as
/// DataType::FLOAT32.
DataType Tensor::type() const {
  GET_CASTED_TENSOR
  switch (tensor->type()) {
    case framework::proto::VarType::FP32:
      return DataType::FLOAT32;
    case framework::proto::VarType::FP64:
      return DataType::FLOAT64;
    case framework::proto::VarType::FP16:
      return DataType::FLOAT16;
    case framework::proto::VarType::INT64:
      return DataType::INT64;
    case framework::proto::VarType::INT32:
      return DataType::INT32;
    case framework::proto::VarType::INT16:
      return DataType::INT16;
    case framework::proto::VarType::INT8:
      return DataType::INT8;
    case framework::proto::VarType::UINT8:
      return DataType::UINT8;
    case framework::proto::VarType::BOOL:
      return DataType::BOOL;
    case framework::proto::VarType::COMPLEX64:
      return DataType::COMPLEX64;
    case framework::proto::VarType::COMPLEX128:
      return DataType::COMPLEX128;
    default:
      // TODO(JiabinYang) Support more dtype here
      return DataType::FLOAT32;
  }
}

template <typename T>
192
Tensor Tensor::copy_to(const PlaceType &target_place) const {
193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
  GET_CASTED_TENSOR;
  PADDLE_ENFORCE_GE(tensor->numel(), 0,
                    platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);
  auto *p_src_data = tensor->data<T>();
  auto src_place = place();
  Tensor target = Tensor(target_place);
  target.reshape(shape());
  auto *p_target_data = target.template mutable_data<T>();

  if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) {
    std::memcpy(static_cast<void *>(p_target_data), p_src_data, ele_size);
  } else if ((src_place == PlaceType::kGPU) &&
             (target_place == PlaceType::kCPU)) {
    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
  } else if ((src_place == PlaceType::kCPU) &&
             (target_place == PlaceType::kGPU)) {
    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
  } else if ((src_place == PlaceType::kGPU) &&
             (target_place == PlaceType::kGPU)) {
    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Not supported place transform of place: %d to place: %d",
        static_cast<int>(src_place), static_cast<int>(target_place)));
  }
  return target;
}

225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
template PD_DLL_DECL Tensor
Tensor::copy_to<float>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<double>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<int64_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<int32_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<uint8_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<int8_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<bool>(const PlaceType &target_place) const;
241 242 243 244
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
    const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
    const PlaceType &target_place) const;
245 246
template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
247

248 249 250 251 252 253 254 255
template PD_DLL_DECL float *Tensor::data<float>() const;
template PD_DLL_DECL double *Tensor::data<double>() const;
template PD_DLL_DECL int64_t *Tensor::data<int64_t>() const;
template PD_DLL_DECL int32_t *Tensor::data<int32_t>() const;
template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL bool *Tensor::data<bool>() const;
256 257 258 259
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
260 261
template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
262

263 264 265 266 267 268 269 270
template PD_DLL_DECL float *Tensor::mutable_data<float>();
template PD_DLL_DECL double *Tensor::mutable_data<double>();
template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>();
template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>();
template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
271 272 273 274
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
275 276
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
277

278 279 280 281 282 283 284 285 286 287 288 289 290 291
template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place);
template PD_DLL_DECL double *Tensor::mutable_data<double>(
    const PlaceType &place);
template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>(
    const PlaceType &place);
template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>(
    const PlaceType &place);
template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>(
    const PlaceType &place);
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
    const PlaceType &place);
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
    const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
292 293 294 295
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
296 297
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
298

C
Chen Weihang 已提交
299
std::vector<int64_t> Tensor::shape() const {
300
  GET_CASTED_TENSOR
C
Chen Weihang 已提交
301
  return framework::vectorize<int64_t>(tensor->dims());
302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
}

/// Refreshes `place_` from the place of the underlying LoDTensor and returns
/// a reference to it.
///
/// @throws Unimplemented if the underlying place is neither CPU nor GPU.
const PlaceType &Tensor::place() const {
  GET_CASTED_TENSOR
  const auto &inner_place = tensor->place();
  if (platform::is_cpu_place(inner_place)) {
    place_ = PlaceType::kCPU;
  } else if (platform::is_gpu_place(inner_place)) {
    place_ = PlaceType::kGPU;
  } else {
    // The old message ran words together ("Init itusing", "T iseither") and
    // described T (an element type) as the place.
    PADDLE_THROW(platform::errors::Unimplemented(
        "Current Tensor holds an unsupported place type. Please initialize "
        "it using Tensor::mutable_data<T>(const PlaceType &place), where "
        "place is either PlaceType::kCPU or PlaceType::kGPU."));
  }
  return place_;
}

319
Tensor Tensor::cast(const DataType &target_type) const {
320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357
  GET_CASTED_TENSOR;
  Tensor rlt = Tensor(place());
  rlt.reshape(this->shape());
  auto rlt_tensor_ = static_cast<framework::LoDTensor *>(rlt.tensor_.get());
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto ctx = pool.Get(tensor->place());
  auto src_type = tensor->type();
  auto dst_type =
      framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type);
  switch (src_type) {
    case framework::proto::VarType::FP32:
      framework::VisitDataType(dst_type,
                               CastDataType<float>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::FP64:
      framework::VisitDataType(dst_type,
                               CastDataType<double>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::INT32:
      framework::VisitDataType(dst_type,
                               CastDataType<int>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::INT64:
      framework::VisitDataType(
          dst_type, CastDataType<int64_t>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::BOOL:
      framework::VisitDataType(dst_type,
                               CastDataType<bool>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::INT16:
      framework::VisitDataType(
          dst_type, CastDataType<int16_t>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::UINT8:
      framework::VisitDataType(
          dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
      break;
358 359 360 361 362 363 364 365 366 367
    case framework::proto::VarType::COMPLEX64:
      framework::VisitDataType(
          dst_type,
          CastDataType<paddle::platform::complex64>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::COMPLEX128:
      framework::VisitDataType(dst_type,
                               CastDataType<paddle::platform::complex128>(
                                   *tensor, rlt_tensor_, ctx));
      break;
368 369 370 371 372
    case framework::proto::VarType::FP16:
      framework::VisitDataType(
          dst_type,
          CastDataType<paddle::platform::float16>(*tensor, rlt_tensor_, ctx));
      break;
373
    // TODO(JiabinYang) Support more dtype here
374 375 376 377 378 379 380 381 382 383 384 385 386
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          framework::DataTypeToString(src_type)));
  }
  return rlt;
}

int64_t Tensor::size() const {
  GET_CASTED_TENSOR;
  return tensor->numel();
}

387 388 389 390 391 392 393 394 395
bool Tensor::is_initialized() const {
  GET_CASTED_TENSOR;
  if (tensor->IsInitialized()) {
    return true;
  } else {
    return false;
  }
}

#ifdef PADDLE_WITH_CUDA
/// Returns the CUDA stream recorded in `stream_`.
///
/// @throws PreconditionNotMet if no stream has been set (per the message,
///         only framework-provided input tensors carry one).
cudaStream_t Tensor::stream() const {
  if (stream_.IsStreamSet()) {
    // GetStream() presumably returns an opaque handle; restore the CUDA type.
    return reinterpret_cast<cudaStream_t>(stream_.GetStream());
  }
  PADDLE_THROW(platform::errors::PreconditionNotMet(
      "Stream is not Set, only input tensor will have "
      "stream which is set by framework "));
}
#endif

408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425
namespace framework {

void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) {
  static_cast<framework::LoDTensor *>(dst)->ShareDataWith(
      *static_cast<framework::LoDTensor *>(src.tensor_.get()));
}

void CustomTensorUtils::ShareDataFrom(const void *src,
                                      const paddle::Tensor &dst) {
  if (!dst.tensor_) {
    dst.tensor_ = std::make_shared<framework::LoDTensor>();
  }
  auto *tensor = static_cast<framework::LoDTensor *>(dst.tensor_.get());
  tensor->ShareDataWith(*static_cast<const framework::LoDTensor *>(src));
}

}  // namespace framework
}  // namespace paddle