dlpack_tensor.cc 6.2 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
W
wanghuancoder 已提交
14
#include "paddle/fluid/framework/dlpack_tensor.h"

#include <memory>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
W
wanghuancoder 已提交
18

S
sneaxiy 已提交
19 20 21 22 23 24 25
namespace paddle {
namespace framework {

namespace internal {
// Maps the C++ element type T to a DLPack dtype descriptor with a single
// lane. The type-code checks run in order:
//   complex<float>/complex<double>   -> kDLComplex
//   bfloat16                         -> kDLBfloat
//   float16 or any floating point    -> kDLFloat
//   any unsigned type                -> kDLUInt
//   any other integral type          -> kDLInt
// Any other T throws platform::errors::Unavailable.
// bits is always 8 * sizeof(T); lanes is always 1.
// NOTE(review): std::is_unsigned<bool>::value is true, so bool (if it is
// registered via _ForEachDataType_) is reported as kDLUInt with
// 8 * sizeof(bool) bits. Confirm this is what DLPack consumers expect.
template <typename T>
static ::DLDataType GetDLDataTypeCode() {
  ::DLDataType dtype;
  if (std::is_same<T, platform::complex<float>>::value ||
      std::is_same<T, platform::complex<double>>::value) {
    dtype.code = kDLComplex;
  } else if (std::is_same<T, platform::bfloat16>::value) {
    dtype.code = kDLBfloat;
  } else if (std::is_same<T, platform::float16>::value ||
             std::is_floating_point<T>::value) {
    dtype.code = kDLFloat;
  } else if (std::is_unsigned<T>::value) {
    dtype.code = kDLUInt;
  } else if (std::is_integral<T>::value) {
    dtype.code = kDLInt;
  } else {
    // The message lists every category accepted by the branches above.
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported data type (%s), only supports float16, bfloat16, float, "
        "complex, unsigned int and int.",
        platform::demangle(typeid(T).name())));
  }
  dtype.bits = 8 * sizeof(T);
  dtype.lanes = 1;
  return dtype;
}

Y
Yu Yang 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61
// Builds a table from proto::VarType::Type (as int) to the DLPack dtype
// descriptor for every (cpp_type, proto_type) pair that _ForEachDataType_
// enumerates. The table is a plain local returned by value. The only caller
// visible in this file stores the result in its own function-local static,
// so there is no reason to keep a second static copy here.
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
  std::unordered_map<int, ::DLDataType> result;

#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
  result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()

  _ForEachDataType_(REG_DL_DATA_TYPE);
#undef REG_DL_DATA_TYPE
  return result;
}

// Returns the DLPack dtype descriptor registered for the given Paddle proto
// type. Throws InvalidArgument if the type has no entry in the table built by
// CreateDLDataTypeMap. The table is built once, on first call, and is never
// modified afterwards, so it is held as const.
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
  static const auto type_to_dtype_map = CreateDLDataTypeMap();
  auto it = type_to_dtype_map.find(static_cast<int>(type));
  PADDLE_ENFORCE_NE(it,
                    type_to_dtype_map.end(),
                    platform::errors::InvalidArgument(
                        "Unsupported data type (%s).", DataTypeToString(type)));
  return it->second;
}

S
Siming Dai 已提交
72 73 74 75 76 77
// Place visitor that converts a Paddle place into a DLPack DLDevice.
// CPU memory is always supported. CUDA and CUDA-pinned memory are supported
// only in CUDA/HIP builds. Every other place kind throws.
struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> {
  // Host memory: reported as kDLCPU, device 0.
  inline ::DLDevice operator()(const platform::CPUPlace &place) const {
    ::DLDevice cpu_dev;
    cpu_dev.device_id = 0;
    cpu_dev.device_type = kDLCPU;
    return cpu_dev;
  }

  // GPU memory: reported as kDLGPU with the place's device ordinal.
  // NOTE(review): the same kDLGPU code is used for HIP builds too. Confirm
  // that consumers of ROCm tensors expect that rather than a ROCm device type.
  inline ::DLDevice operator()(const platform::CUDAPlace &place) const {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPlace is not supported in CPU only version."));
#else
    ::DLDevice gpu_dev;
    gpu_dev.device_id = place.device;
    gpu_dev.device_type = kDLGPU;
    return gpu_dev;
#endif
  }

  // Page-locked host memory: reported as kDLCPUPinned, device 0.
  inline ::DLDevice operator()(const platform::CUDAPinnedPlace &place) const {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPinnedPlace is not supported in CPU only version."));
#else
    ::DLDevice pinned_dev;
    pinned_dev.device_id = 0;
    pinned_dev.device_type = kDLCPUPinned;
    return pinned_dev;
#endif
  }

  // The remaining place kinds have no DLPack mapping here; each one raises
  // an Unimplemented error.
  inline ::DLDevice operator()(const platform::IPUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::IPUPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::XPUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::XPUPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::NPUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::NPUPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::NPUPinnedPlace &place) const {
    PADDLE_THROW(platform::errors::Unimplemented(
        "platform::NPUPinnedPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::MLUPlace &place) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::MLUPlace is not supported"));
  }

  inline ::DLDevice operator()(const platform::CustomPlace &place) const {
    PADDLE_THROW(platform::errors::Unimplemented(
        "platform::CustomPlace is not supported"));
  }
};
}  // namespace internal

// Builds a non-owning DLTensor view (t_) over `tensor`:
//   data   -> tensor's buffer (constness cast away, as DLTensor::data is void*)
//   device -> derived from tensor.place() via internal::DLDeviceVisitor
//   dtype  -> derived from tensor.dtype(), with `lanes` overriding the lane count
//   shape  -> copied into the member buffer shape_
//   strides/byte_offset -> nullptr / 0 (compact layout)
// Throws if the place or dtype is not supported by the internal helpers.
DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
  t_.data = const_cast<void *>(tensor.data());

  auto src_place = tensor.place();
  t_.device =
      paddle::platform::VisitPlace(src_place, internal::DLDeviceVisitor());

  const auto proto_type = framework::TransToProtoVarType(tensor.dtype());
  t_.dtype = internal::GetDLDataTypeFromTypeIndex(proto_type);
  t_.dtype.lanes = lanes;

  // NOTE(review): assumes shape_ (declared in the header) has room for
  // dims.size() entries.
  const auto &src_dims = tensor.dims();
  using DimType = decltype(t_.ndim);  // int
  t_.ndim = static_cast<DimType>(src_dims.size());
  t_.shape = shape_;
  for (DimType axis = 0; axis < t_.ndim; ++axis) {
    t_.shape[axis] = src_dims[axis];
  }

  // A null strides pointer tells DLPack consumers the tensor is compact.
  t_.strides = nullptr;
  t_.byte_offset = 0;
}

S
Siming Dai 已提交
167 168
// Returns a heap-allocated DLManagedTensor describing the same data as t_.
// The managed tensor owns its own copies of the shape and strides arrays, and
// its deleter frees those arrays plus the DLManagedTensor itself. It does NOT
// free dl_tensor.data, and manager_ctx is nullptr.
// NOTE(review): the data pointer is borrowed from the source Tensor, so the
// caller presumably must keep that Tensor alive until the deleter runs.
// t_ is left unmodified. Previously its shape/strides were repointed at the
// arrays the consumer's deleter frees, which left t_ dangling and made a
// second call read freed memory.
::DLManagedTensor *DLPackTensor::ToDLManagedTensor() {
  using DimType = decltype(t_.ndim);  // int
  const DimType ndim = t_.ndim;

  // Shape copy owned by the returned tensor. unique_ptr keeps it leak-free
  // if a later allocation throws.
  std::unique_ptr<int64_t[]> shape(new int64_t[ndim]);
  for (DimType i = 0; i < ndim; ++i) {
    shape[i] = t_.shape[i];
  }

  // Row-major strides counted in elements: the last axis has stride 1, and
  // each earlier axis steps over the product of all later extents.
  std::unique_ptr<int64_t[]> strides(new int64_t[ndim]);
  for (DimType i = 0; i < ndim; ++i) {
    strides[i] = 1;
  }
  for (DimType i = ndim - 2; i >= 0; --i) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }

  std::unique_ptr<::DLManagedTensor> tensor(new ::DLManagedTensor);
  tensor->dl_tensor = t_;
  // All allocations have succeeded; hand array ownership to the managed
  // tensor (and only to it, not to t_).
  tensor->dl_tensor.shape = shape.release();
  tensor->dl_tensor.strides = strides.release();
  tensor->manager_ctx = nullptr;
  tensor->deleter = [](::DLManagedTensor *arg) {
    delete[] arg->dl_tensor.shape;
    delete[] arg->dl_tensor.strides;
    delete arg;
  };

  return tensor.release();
}

S
sneaxiy 已提交
200 201
}  // namespace framework
}  // namespace paddle