dlpack_tensor.cc 5.9 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
W
wanghuancoder 已提交
14
#include "paddle/fluid/framework/dlpack_tensor.h"

#include <memory>

#include "paddle/fluid/framework/data_type.h"
W
wanghuancoder 已提交
16

S
sneaxiy 已提交
17 18 19 20 21 22 23
namespace paddle {
namespace framework {

namespace internal {
// Returns the DLPack data type (code, bits, lanes) describing the C++ element
// type T. `bits` is 8 * sizeof(T) and `lanes` is always 1; callers may
// override `lanes` afterwards.
//
// Throws platform::errors::Unavailable if T is not a complex, bfloat16,
// float16, floating-point, unsigned or signed integral type.
template <typename T>
static ::DLDataType GetDLDataTypeCode() {
  ::DLDataType dtype;
  // The Paddle-specific types (complex, bfloat16, float16) are matched
  // explicitly first; the std:: traits below only recognise built-in types.
  if (std::is_same<T, platform::complex<float>>::value ||
      std::is_same<T, platform::complex<double>>::value) {
    dtype.code = kDLComplex;
  } else if (std::is_same<T, platform::bfloat16>::value) {
    dtype.code = kDLBfloat;
  } else if (std::is_same<T, platform::float16>::value ||
             std::is_floating_point<T>::value) {
    dtype.code = kDLFloat;
  } else if (std::is_unsigned<T>::value) {
    // Note: std::is_unsigned<bool> is true, so bool (if registered) maps here.
    dtype.code = kDLUInt;
  } else if (std::is_integral<T>::value) {
    dtype.code = kDLInt;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported data type (%s), only supports complex64, complex128, "
        "bfloat16, float16, float, unsigned int and int.",
        platform::demangle(typeid(T).name())));
  }
  dtype.bits = 8 * sizeof(T);
  dtype.lanes = 1;
  return dtype;
}

Y
Yu Yang 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59
// Builds a table mapping each proto data type registered through
// _ForEachDataType_ (keyed by its int value) to the matching DLPack data
// type produced by GetDLDataTypeCode.
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
  // A plain local suffices: the table is returned by value, and the caller
  // visible in this file caches it in its own function-local static.
  std::unordered_map<int, ::DLDataType> result;

#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
  result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()

  _ForEachDataType_(REG_DL_DATA_TYPE);
#undef REG_DL_DATA_TYPE
  return result;
}

// Looks up the DLPack data type corresponding to a Paddle proto data type.
// Raises an InvalidArgument error (via PADDLE_ENFORCE_NE) when `type` has no
// entry in the table built by CreateDLDataTypeMap.
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
  // Built once on first use and never modified afterwards.
  static const auto type_to_dtype_map = CreateDLDataTypeMap();
  auto it = type_to_dtype_map.find(static_cast<int>(type));
  PADDLE_ENFORCE_NE(it, type_to_dtype_map.end(),
                    platform::errors::InvalidArgument(
                        "Unsupported data type (%s).", DataTypeToString(type)));
  return it->second;
}

S
Siming Dai 已提交
69 70 71 72 73 74
// Place visitor that translates a platform place into a DLPack device
// descriptor.
// - CPUPlace always maps to {kDLCPU, 0}.
// - CUDAPlace and CUDAPinnedPlace map to kDLGPU / kDLCPUPinned in CUDA or HIP
//   builds, and throw Unavailable otherwise.
// - IPU, XPU, NPU, NPU-pinned and MLU places always throw Unimplemented.
struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> {
  ::DLDevice operator()(const platform::CPUPlace & /*place*/) const {
    return MakeDevice(kDLCPU, 0);
  }

  // ---- Places that have no DLPack representation here ----

  ::DLDevice operator()(const platform::IPUPlace & /*place*/) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::IPUPlace is not supported"));
  }

  ::DLDevice operator()(const platform::XPUPlace & /*place*/) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::XPUPlace is not supported"));
  }

  ::DLDevice operator()(const platform::NPUPlace & /*place*/) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::NPUPlace is not supported"));
  }

  ::DLDevice operator()(const platform::NPUPinnedPlace & /*place*/) const {
    PADDLE_THROW(platform::errors::Unimplemented(
        "platform::NPUPinnedPlace is not supported"));
  }

  ::DLDevice operator()(const platform::MLUPlace & /*place*/) const {
    PADDLE_THROW(
        platform::errors::Unimplemented("platform::MLUPlace is not supported"));
  }

  // ---- GPU places, only available in CUDA/HIP builds ----

  ::DLDevice operator()(const platform::CUDAPlace &place) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    return MakeDevice(kDLGPU, place.device);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPlace is not supported in CPU only version."));
#endif
  }

  ::DLDevice operator()(const platform::CUDAPinnedPlace & /*place*/) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    return MakeDevice(kDLCPUPinned, 0);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "platform::CUDAPinnedPlace is not supported in CPU only version."));
#endif
  }

 private:
  // Fills in both fields of a DLDevice.
  static ::DLDevice MakeDevice(::DLDeviceType type, int id) {
    ::DLDevice device;
    device.device_type = type;
    device.device_id = id;
    return device;
  }
};
}  // namespace internal

// Describes `tensor` as a DLTensor (stored in t_) without copying its data.
// - t_.data borrows the tensor's data pointer.
// - The extents are copied into this object's shape_ member, and t_.shape
//   points at it.
// - Strides are left null (compact layout) and byte_offset is zero.
// - `lanes` replaces the lane count of the looked-up data type.
DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
  using DimType = decltype(t_.ndim);  // int

  // Data: borrowed, not copied. DLTensor::data is a mutable pointer, hence
  // the const_cast.
  t_.data = const_cast<void *>(tensor.data());

  // Device type and id, derived from the tensor's place.
  auto tensor_place = tensor.place();
  t_.device =
      paddle::platform::VisitPlace(tensor_place, internal::DLDeviceVisitor());

  // Element type, with the caller-supplied lane count.
  t_.dtype = internal::GetDLDataTypeFromTypeIndex(tensor.type());
  t_.dtype.lanes = lanes;

  // Rank and extents.
  // NOTE(review): no check that the rank fits in shape_'s capacity (declared
  // in the header) — confirm the header bounds it.
  const auto &tensor_dims = tensor.dims();
  t_.ndim = static_cast<DimType>(tensor_dims.size());
  for (DimType axis = 0; axis < t_.ndim; ++axis) {
    shape_[axis] = tensor_dims[axis];
  }
  t_.shape = shape_;

  // Null strides denote a compact row-major tensor; data starts at offset 0.
  t_.strides = nullptr;
  t_.byte_offset = 0;
}

S
Siming Dai 已提交
158 159
// Returns a newly allocated DLManagedTensor describing the same data as t_.
//
// Ownership of the result:
// - It owns private copies of the shape and strides arrays.
// - Its `deleter` frees those arrays and the DLManagedTensor itself, but not
//   the data buffer. manager_ctx is null.
//
// Strides are filled in as compact row-major strides, counted in elements.
//
// t_ is left unmodified. Calling this method repeatedly is therefore safe,
// and the result never shares storage with this object's shape buffer.
::DLManagedTensor *DLPackTensor::ToDLManagedTensor() {
  using DimType = decltype(t_.ndim);  // int
  const DimType ndim = t_.ndim;

  // Keep every allocation in a unique_ptr until the tensor is fully built,
  // so a throwing allocation part-way through does not leak earlier ones.
  std::unique_ptr<int64_t[]> shape(new int64_t[ndim]);
  std::unique_ptr<int64_t[]> strides(new int64_t[ndim]);
  std::unique_ptr<::DLManagedTensor> tensor(new ::DLManagedTensor());

  // init shape: the result gets its own copy of the extents
  for (DimType i = 0; i < ndim; ++i) {
    shape[i] = t_.shape[i];
  }

  // init strides: innermost dimension has stride 1; each outer stride is the
  // product of all inner extents
  for (DimType i = 0; i < ndim; ++i) {
    strides[i] = 1;
  }
  for (DimType i = ndim - 2; i >= 0; --i) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }

  tensor->dl_tensor = t_;
  tensor->dl_tensor.shape = shape.release();
  tensor->dl_tensor.strides = strides.release();
  tensor->manager_ctx = nullptr;
  tensor->deleter = [](::DLManagedTensor *arg) {
    delete[] arg->dl_tensor.shape;
    delete[] arg->dl_tensor.strides;
    delete arg;
  };

  return tensor.release();
}

S
sneaxiy 已提交
191 192
}  // namespace framework
}  // namespace paddle