dlpack_tensor.cc 6.0 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
W
wanghuancoder 已提交
14
#include "paddle/fluid/framework/dlpack_tensor.h"
Y
Yu Yang 已提交
15
#include <memory>

#include "paddle/fluid/framework/data_type.h"
W
wanghuancoder 已提交
16 17 18 19 20 21 22 23

namespace paddle {
namespace platform {
struct bfloat16;
struct float16;
}  // namespace platform
}  // namespace paddle

S
sneaxiy 已提交
24 25 26 27 28 29 30
namespace paddle {
namespace framework {

namespace internal {
template <typename T>
// Maps the C++ element type T onto a DLPack data type descriptor.
//
// code:  kDLComplex for platform::complex<float>/complex<double>,
//        kDLBfloat for platform::bfloat16, kDLFloat for platform::float16
//        and built-in floating point types, kDLUInt for unsigned integral
//        types, kDLInt for the remaining integral types.
//        Note: std::is_unsigned<bool> is true, so a bool T lands in kDLUInt.
// bits:  8 * sizeof(T), i.e. the full storage width (both halves for complex).
// lanes: always 1 here; callers may override it.
//
// Throws platform::errors::Unavailable for any other T.
static ::DLDataType GetDLDataTypeCode() {
  ::DLDataType dtype;
  if (std::is_same<T, platform::complex<float>>::value ||
      std::is_same<T, platform::complex<double>>::value) {
    dtype.code = kDLComplex;
  } else if (std::is_same<T, platform::bfloat16>::value) {
    dtype.code = kDLBfloat;
  } else if (std::is_same<T, platform::float16>::value ||
             std::is_floating_point<T>::value) {
    dtype.code = kDLFloat;
  } else if (std::is_unsigned<T>::value) {
    dtype.code = kDLUInt;
  } else if (std::is_integral<T>::value) {
    dtype.code = kDLInt;
  } else {
    // Keep this list in sync with the branches above.
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported data type (%s), only supports complex64, complex128, "
        "bfloat16, float16, float, double, unsigned int and int.",
        platform::demangle(typeid(T).name())));
  }
  dtype.bits = 8 * sizeof(T);
  dtype.lanes = 1;
  return dtype;
}

Y
Yu Yang 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66
// Builds a table from proto::VarType::Type (stored as int) to ::DLDataType.
// It expands _ForEachDataType_ (from data_type.h) with REG_DL_DATA_TYPE,
// presumably once per (C++ type, proto type) pair it knows about.
// GetDLDataTypeCode throws for an unsupported C++ type, so such an entry
// would fail here at construction time.
// The map is a plain local returned by value; the caller
// (GetDLDataTypeFromTypeIndex) is responsible for caching it.
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
  std::unordered_map<int, ::DLDataType> result;

#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
  result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()

  _ForEachDataType_(REG_DL_DATA_TYPE);
#undef REG_DL_DATA_TYPE
  return result;
}

// Looks up the DLPack data type for a Paddle proto data type.
// Throws platform::errors::InvalidArgument if the type has no entry.
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
  // Built once on first use. Function-local static initialization is
  // thread-safe since C++11, and the map is read-only afterwards.
  static const auto type_to_dtype_map = CreateDLDataTypeMap();
  auto it = type_to_dtype_map.find(static_cast<int>(type));
  PADDLE_ENFORCE_NE(it, type_to_dtype_map.end(),
                    platform::errors::InvalidArgument(
                        "Unsupported data type (%s).", DataTypeToString(type)));
  return it->second;
}

S
Siming Dai 已提交
76 77 78 79 80 81
// boost::variant visitor that turns a Paddle place into a DLPack ::DLDevice.
// Host memory and (in CUDA/HIP builds) GPU / pinned memory are supported.
// Every other place kind throws.
struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> {
  // ---- Supported places -------------------------------------------------

  // Plain host memory: always CPU device 0.
  ::DLDevice operator()(const platform::CPUPlace &) const {
    ::DLDevice host;
    host.device_id = 0;
    host.device_type = kDLCPU;
    return host;
  }

  // GPU memory: device id is taken from the place. Only available when
  // built with CUDA or HIP; otherwise throws Unavailable.
  ::DLDevice operator()(const platform::CUDAPlace &place) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    ::DLDevice gpu;
    gpu.device_id = place.device;
    gpu.device_type = kDLGPU;
    return gpu;
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPlace is not supported in CPU only version."));
#endif
  }

  // Page-locked host memory: reported as device 0 of kDLCPUPinned.
  // Only available when built with CUDA or HIP; otherwise throws Unavailable.
  ::DLDevice operator()(const platform::CUDAPinnedPlace &) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    ::DLDevice pinned;
    pinned.device_id = 0;
    pinned.device_type = kDLCPUPinned;
    return pinned;
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPinnedPlace is not supported in CPU only version."));
#endif
  }

  // ---- Unsupported places: always throw Unimplemented -------------------

  ::DLDevice operator()(const platform::IPUPlace &) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::IPUPlace is not supported"));
  }

  ::DLDevice operator()(const platform::XPUPlace &) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::XPUPlace is not supported"));
  }

  ::DLDevice operator()(const platform::NPUPlace &) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::NPUPlace is not supported"));
  }

  ::DLDevice operator()(const platform::NPUPinnedPlace &) const {
    PADDLE_THROW(platform::errors::Unimplemented(
        "platform::NPUPinnedPlace is not supported"));
  }

  ::DLDevice operator()(const platform::MLUPlace &) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::MLUPlace is not supported"));
  }
};
}  // namespace internal

// Describes `tensor` as a DLPack DLTensor without copying its data.
// The data pointer is borrowed from `tensor`. The shape array lives in this
// object's own `shape_` buffer. Strides are left null (compact layout) and
// byte_offset is 0.
DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
  // Data: DLTensor::data is a non-const pointer, hence the const_cast.
  t_.data = const_cast<void *>(tensor.data());

  // Device: translate the Paddle place via the visitor (may throw for
  // unsupported places).
  auto src_place = tensor.place();
  t_.device = boost::apply_visitor(internal::DLDeviceVisitor(), src_place);

  // Element type, with the lane count supplied by the caller.
  t_.dtype = internal::GetDLDataTypeFromTypeIndex(tensor.type());
  t_.dtype.lanes = lanes;

  // Rank and extents, copied into the member buffer shape_.
  const auto &src_dims = tensor.dims();
  using DimType = decltype(t_.ndim);
  t_.ndim = static_cast<DimType>(src_dims.size());
  t_.shape = shape_;
  for (DimType axis = 0; axis < t_.ndim; ++axis) {
    t_.shape[axis] = src_dims[axis];
  }

  // nullptr strides signal a compact layout; data starts at offset 0.
  t_.strides = nullptr;
  t_.byte_offset = 0;
}

S
Siming Dai 已提交
165 166
// Produces a heap-allocated ::DLManagedTensor describing the same data as t_.
//
// The returned tensor owns its own copies of the shape array and of a freshly
// computed row-major (compact) stride array. Its `deleter` frees those arrays
// and the DLManagedTensor itself. The data buffer is NOT owned (manager_ctx is
// nullptr), so the underlying Paddle tensor must outlive the consumer.
//
// Unlike the previous implementation, t_ is left untouched. Previously
// t_.shape/t_.strides were redirected to the heap arrays, which dangled as
// soon as the consumer invoked the deleter.
::DLManagedTensor *DLPackTensor::ToDLManagedTensor() {
  using DimType = decltype(t_.ndim);  // int
  const DimType ndim = t_.ndim;

  // Hold every allocation in a unique_ptr until the end, so nothing leaks if
  // a later allocation throws. new int64_t[0] is valid for rank-0 tensors.
  std::unique_ptr<int64_t[]> shape(new int64_t[ndim]);
  for (DimType i = 0; i < ndim; ++i) {
    shape[i] = t_.shape[i];
  }

  // Row-major contiguous strides (in elements): the innermost stride is 1 and
  // each outer stride is the product of the inner extents.
  std::unique_ptr<int64_t[]> strides(new int64_t[ndim]);
  for (DimType i = 0; i < ndim; ++i) {
    strides[i] = 1;
  }
  for (DimType i = ndim - 2; i >= 0; --i) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }

  std::unique_ptr<::DLManagedTensor> tensor(new ::DLManagedTensor);
  tensor->dl_tensor = t_;
  // Nothing below can throw: hand the arrays over to the managed tensor.
  tensor->dl_tensor.shape = shape.release();
  tensor->dl_tensor.strides = strides.release();

  tensor->deleter = [](DLManagedTensor *arg) {
    delete[] arg->dl_tensor.shape;
    delete[] arg->dl_tensor.strides;
    delete arg;
  };

  tensor->manager_ctx = nullptr;

  return tensor.release();
}

S
sneaxiy 已提交
198 199
}  // namespace framework
}  // namespace paddle