// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"

namespace paddle {
namespace framework {

namespace internal {
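// Maps the C++ element type T onto a ::DLDataType (code, bits, lanes).
// Complex, bfloat16, float16, floating-point, unsigned and signed integral
// types are recognized; any other type raises an Unavailable error.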
template <typename T>
static ::DLDataType GetDLDataTypeCode() {
  ::DLDataType dtype;
  if (std::is_same<T, platform::complex<float>>::value ||
      std::is_same<T, platform::complex<double>>::value) {
    dtype.code = kDLComplex;
  } else if (std::is_same<T, platform::bfloat16>::value) {
    dtype.code = kDLBfloat;
  } else if (std::is_same<T, platform::float16>::value ||
             std::is_floating_point<T>::value) {
    dtype.code = kDLFloat;
  } else if (std::is_unsigned<T>::value) {
    dtype.code = kDLUInt;
  } else if (std::is_integral<T>::value) {
    dtype.code = kDLInt;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported data type (%s), only complex, bfloat16, float16, "
        "floating point, unsigned integral and integral types are supported.",
        platform::demangle(typeid(T).name())));
  }
  dtype.bits = 8 * sizeof(T);
  dtype.lanes = 1;
  return dtype;
}

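// Builds the lookup table from proto::VarType::Type (stored as int) to
// ::DLDataType, covering every type registered via _ForEachDataType_.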
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
  static std::unordered_map<int, ::DLDataType> result;

#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
  result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()

  _ForEachDataType_(REG_DL_DATA_TYPE);
#undef REG_DL_DATA_TYPE
  return result;
}

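// Returns the ::DLDataType for a proto::VarType::Type. The table is built
// once and cached in a function-local static; unknown types are rejected
// with an InvalidArgument error.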
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
  static auto type_to_dtype_map = CreateDLDataTypeMap();
  static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
  auto it = type_to_dtype_map.find(static_cast<int>(type));
  PADDLE_ENFORCE_NE(it,
                    type_to_dtype_map_end_it,
                    platform::errors::InvalidArgument(
                        "Unsupported data type (%s).", DataTypeToString(type)));
  return it->second;
}

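// Place visitor that translates a platform::Place into a ::DLDevice.
// Only CPUPlace, CUDAPlace and CUDAPinnedPlace can be expressed (the latter
// two only in CUDA/HIP builds); every other place throws Unimplemented.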
struct DLDeviceVisitor
    : public std::unary_function<const platform::Place &, ::DLDevice> {
  inline ::DLDevice operator()(const platform::CPUPlace &place) const {
    ::DLDevice device;
    device.device_type = kDLCPU;
    device.device_id = 0;
    return device;
  }

  inline ::DLDevice operator()(const platform::IPUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::IPUPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::XPUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::XPUPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::NPUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::NPUPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::NPUPinnedPlace &place) const {
    PADDLE_THROW(platform::errors::Unimplemented(
        "platform::NPUPinnedPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::MLUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::MLUPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::CustomPlace &place) const {
    PADDLE_THROW(platform::errors::Unimplemented(
        "platform::CustomPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::CUDAPlace &place) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    ::DLDevice device;
    device.device_type = kDLGPU;
    device.device_id = place.device;
    return device;
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPlace is not supported in CPU only version."));
#endif
  }

  inline ::DLDevice operator()(const platform::CUDAPinnedPlace &place) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    ::DLDevice device;
    device.device_type = kDLCPUPinned;
    device.device_id = 0;
    return device;
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPinnedPlace is not supported in CPU only version."));
#endif
  }
};
}  // namespace internal

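// Pairs the exported DLManagedTensor with a phi::DenseTensor handle that
// shares ownership of the underlying buffer, keeping the data alive until
// the DLManagedTensor's deleter runs.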
struct PaddleDLMTensor {
  phi::DenseTensor handle;
  DLManagedTensor tensor;
};

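// Deleter installed on tensors produced by toDLPack(): releases the
// heap-allocated shape/strides arrays and the owning PaddleDLMTensor.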
void deleter(DLManagedTensor *arg) {
  delete[] arg->dl_tensor.shape;
  delete[] arg->dl_tensor.strides;
  delete static_cast<PaddleDLMTensor *>(arg->manager_ctx);
}

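// Wraps a phi::DenseTensor in a newly allocated DLManagedTensor. Shape and
// compact row-major strides are copied to the heap, and the source buffer is
// kept alive through PaddleDLMTensor until the deleter is invoked.
//
// Usage sketch (illustrative only; follows the usual DLPack ownership rule
// that whoever ends up holding the DLManagedTensor calls its deleter once):
//
//   DLManagedTensor *dlmt = paddle::framework::toDLPack(dense_tensor);
//   // ... hand dlmt to a DLPack consumer ...
//   // dlmt->deleter(dlmt);  // called exactly once by the final owner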
DLManagedTensor *toDLPack(const phi::DenseTensor &src) {
  PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor);
  pdDLMTensor->handle = const_cast<phi::DenseTensor &>(src);
  pdDLMTensor->tensor.manager_ctx = pdDLMTensor;
  pdDLMTensor->tensor.deleter = &deleter;
  pdDLMTensor->tensor.dl_tensor.data = const_cast<void *>(src.data());

  // init ndim
  using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim);  // int
  pdDLMTensor->tensor.dl_tensor.ndim = static_cast<DimType>(src.dims().size());
  DimType ndim = pdDLMTensor->tensor.dl_tensor.ndim;

  // init shape
  auto shape = new int64_t[ndim];
  for (DimType i = 0; i < ndim; ++i) {
    shape[i] = src.dims()[i];
  }
  pdDLMTensor->tensor.dl_tensor.shape = shape;

  // init strides: compact row-major layout, innermost dimension has stride 1
  auto strides = new int64_t[ndim];
  for (DimType i = 0; i < ndim; ++i) {
    strides[i] = 1;
  }
  for (DimType i = ndim - 2; i >= 0; --i) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }
  pdDLMTensor->tensor.dl_tensor.strides = strides;

  // init device: a ::DLDevice holds the device_type and device_id
  auto place = src.place();
  pdDLMTensor->tensor.dl_tensor.device =
      paddle::platform::VisitPlace(place, internal::DLDeviceVisitor());

  pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex(
      framework::TransToProtoVarType(src.dtype()));

  pdDLMTensor->tensor.dl_tensor.byte_offset = 0;
  return &(pdDLMTensor->tensor);
}

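// Fills the borrowed ::DLTensor view t_ from a phi::DenseTensor. The shape is
// stored in the member array shape_ and strides are left null (compact
// layout), so the view is only valid while this DLPackTensor and the source
// tensor are alive.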
DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
  // init data: pointer to the tensor's underlying data buffer
  t_.data = const_cast<void *>(tensor.data());

  // init device: a ::DLDevice holds the device_type and device_id
  auto place = tensor.place();
  t_.device = paddle::platform::VisitPlace(place, internal::DLDeviceVisitor());

  // init dtype
  t_.dtype = internal::GetDLDataTypeFromTypeIndex(
      framework::TransToProtoVarType(tensor.dtype()));
  t_.dtype.lanes = lanes;

  // init ndim, tensor rank
  auto &dims = tensor.dims();
  using DimType = decltype(t_.ndim);  // int
  t_.ndim = static_cast<DimType>(dims.size());

  // init shape, tensor dims
  t_.shape = shape_;
  for (DimType i = 0; i < t_.ndim; ++i) {
    t_.shape[i] = dims[i];
  }

  // init strides, nullptr means the tensor is compact
  t_.strides = nullptr;

  // init byte_offset
  t_.byte_offset = 0;
}

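// Converts the borrowed view into a heap-allocated DLManagedTensor: shape is
// copied out, compact row-major strides are materialized, and the attached
// deleter frees both arrays and the DLManagedTensor itself. Unlike toDLPack(),
// manager_ctx is null and no ownership of the data buffer is taken, so the
// source tensor must outlive the returned object.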
::DLManagedTensor *DLPackTensor::ToDLManagedTensor() {
  // init shape
  auto shape = new int64_t[t_.ndim];
  using DimType = decltype(t_.ndim);  // int
  for (DimType i = 0; i < t_.ndim; ++i) {
    shape[i] = t_.shape[i];
  }
  t_.shape = shape;

  // init strides
  auto strides = new int64_t[t_.ndim];
  for (DimType i = 0; i < t_.ndim; ++i) {
    strides[i] = 1;
  }
  for (DimType i = t_.ndim - 2; i >= 0; --i) {
    strides[i] = t_.shape[i + 1] * strides[i + 1];
  }
  t_.strides = strides;

  auto tensor = new DLManagedTensor;
  tensor->dl_tensor = t_;

  tensor->deleter = [](DLManagedTensor *arg) {
    delete[] arg->dl_tensor.shape;
    delete[] arg->dl_tensor.strides;
    delete arg;
  };

  tensor->manager_ctx = nullptr;

  return tensor;
}

}  // namespace framework
}  // namespace paddle