// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h"

#include <memory>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"

namespace paddle {
namespace framework {

namespace internal {
// Returns the DLPack dtype descriptor for the C++ element type T:
//   - code:  kDLComplex, kDLBfloat, kDLFloat, kDLUInt or kDLInt
//   - bits:  8 * sizeof(T) (so a complex type reports the width of both parts)
//   - lanes: always 1; callers widen it afterwards if needed.
// Throws platform::errors::Unavailable for any other T.
template <typename T>
static ::DLDataType GetDLDataTypeCode() {
  ::DLDataType dtype;
  if (std::is_same<T, platform::complex<float>>::value ||
      std::is_same<T, platform::complex<double>>::value) {
    dtype.code = kDLComplex;
  } else if (std::is_same<T, platform::bfloat16>::value) {
    dtype.code = kDLBfloat;
  } else if (std::is_same<T, platform::float16>::value ||
             std::is_floating_point<T>::value) {
    dtype.code = kDLFloat;
  } else if (std::is_unsigned<T>::value) {
    // Note: bool satisfies std::is_unsigned, so it lands here as kDLUInt.
    dtype.code = kDLUInt;
  } else if (std::is_integral<T>::value) {
    dtype.code = kDLInt;
  } else {
    // The message lists every family accepted by the branches above.
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported data type (%s), only supports float16, bfloat16, float, "
        "double, complex64, complex128, unsigned int and int.",
        platform::demangle(typeid(T).name())));
  }
  dtype.bits = 8 * sizeof(T);
  dtype.lanes = 1;
  return dtype;
}

Y
Yu Yang 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
  static std::unordered_map<int, ::DLDataType> result;

#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
  result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()

  _ForEachDataType_(REG_DL_DATA_TYPE);
#undef REG_DL_DATA_TYPE
  return result;
}

// Returns the DLPack dtype for a Paddle proto data type.
// The lookup table is built on first call (function-local static) and is
// read-only afterwards. Throws InvalidArgument if `type` has no entry.
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
  static const auto type_to_dtype_map = CreateDLDataTypeMap();
  auto it = type_to_dtype_map.find(static_cast<int>(type));
  PADDLE_ENFORCE_NE(it, type_to_dtype_map.end(),
                    platform::errors::InvalidArgument(
                        "Unsupported data type (%s).", DataTypeToString(type)));
  return it->second;
}

S
Siming Dai 已提交
70 71 72 73 74 75
struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> {
  inline ::DLDevice operator()(const platform::CPUPlace &place) const {
    ::DLDevice device;
    device.device_type = kDLCPU;
    device.device_id = 0;
    return device;
S
sneaxiy 已提交
76 77
  }

J
jianghaicheng 已提交
78 79 80 81 82
  inline ::DLDevice operator()(const platform::IPUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::IPUPlace is not supported"));
  }

S
Siming Dai 已提交
83
  inline ::DLDevice operator()(const platform::XPUPlace &place) const {
84 85 86 87
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::XPUPlace is not supported"));
  }

S
Siming Dai 已提交
88
  inline ::DLDevice operator()(const platform::NPUPlace &place) const {
89 90 91 92
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::NPUPlace is not supported"));
  }

S
Siming Dai 已提交
93
  inline ::DLDevice operator()(const platform::NPUPinnedPlace &place) const {
94 95 96 97
    PADDLE_THROW(platform::errors::Unimplemented(
        "platform::NPUPinnedPlace is not supported"));
  }

F
fwenguang 已提交
98 99 100 101 102
  inline ::DLDevice operator()(const platform::MLUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::MLUPlace is not supported"));
  }

S
Siming Dai 已提交
103
  inline ::DLDevice operator()(const platform::CUDAPlace &place) const {
104
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
S
Siming Dai 已提交
105 106 107 108
    ::DLDevice device;
    device.device_type = kDLGPU;
    device.device_id = place.device;
    return device;
S
sneaxiy 已提交
109
#else
110 111
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPlace is not supported in CPU only version."));
S
sneaxiy 已提交
112 113 114
#endif
  }

S
Siming Dai 已提交
115
  inline ::DLDevice operator()(const platform::CUDAPinnedPlace &place) const {
116
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
S
Siming Dai 已提交
117 118 119 120
    ::DLDevice device;
    device.device_type = kDLCPUPinned;
    device.device_id = 0;
    return device;
S
sneaxiy 已提交
121
#else
122 123
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPinnedPlace is not supported in CPU only version."));
S
sneaxiy 已提交
124 125 126 127 128 129 130
#endif
  }
};
}  // namespace internal

DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
  // Compact row-major layout: DLPack accepts strides == nullptr for that case,
  // and the data starts at the buffer's first byte.
  t_.strides = nullptr;
  t_.byte_offset = 0;

  // Borrow the tensor's buffer. DLTensor::data is non-const, hence the cast.
  t_.data = const_cast<void *>(tensor.data());

  // Device descriptor (type + id) derived from the tensor's place.
  auto place = tensor.place();
  t_.device = paddle::platform::VisitPlace(place, internal::DLDeviceVisitor());

  // Element type; the lane count is then overridden by the caller's value.
  const auto proto_type = framework::TransToProtoVarType(tensor.dtype());
  t_.dtype = internal::GetDLDataTypeFromTypeIndex(proto_type);
  t_.dtype.lanes = lanes;

  // Rank and per-axis extents. shape points at this object's own shape_
  // storage, so it stays valid only as long as this DLPackTensor does.
  using DimType = decltype(t_.ndim);  // int
  const auto &dims = tensor.dims();
  const auto rank = static_cast<DimType>(dims.size());
  t_.ndim = rank;
  t_.shape = shape_;
  for (DimType axis = 0; axis < rank; ++axis) {
    t_.shape[axis] = dims[axis];
  }
}

S
Siming Dai 已提交
160 161
::DLManagedTensor *DLPackTensor::ToDLManagedTensor() {
  // init shape
6
633WHU 已提交
162 163 164 165 166 167 168
  auto shape = new int64_t[t_.ndim];
  using DimType = decltype(t_.ndim);  // int
  for (DimType i = 0; i < t_.ndim; ++i) {
    shape[i] = t_.shape[i];
  }
  t_.shape = shape;

S
Siming Dai 已提交
169 170 171 172 173 174 175 176 177
  // init strides
  auto strides = new int64_t[t_.ndim];
  for (DimType i = 0; i < t_.ndim; ++i) {
    strides[i] = 1;
  }
  for (DimType i = t_.ndim - 2; i >= 0; --i) {
    strides[i] = t_.shape[i + 1] * strides[i + 1];
  }
  t_.strides = strides;
6
633WHU 已提交
178 179 180 181 182 183 184 185 186 187 188 189 190 191 192

  auto tensor = new DLManagedTensor;
  tensor->dl_tensor = t_;

  tensor->deleter = [](DLManagedTensor *arg) {
    delete[] arg->dl_tensor.shape;
    delete[] arg->dl_tensor.strides;
    delete arg;
  };

  tensor->manager_ctx = nullptr;

  return tensor;
}

}  // namespace framework
}  // namespace paddle