// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/dlpack_tensor.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <vector>

namespace paddle {
namespace platform {
struct float16;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace framework {

namespace {  // NOLINT
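// Maps a C++ element type to its DLPack type code: complex -> kDLComplex,
// bfloat16 -> kDLBfloat, float16 / floating point -> kDLFloat, unsigned
// integral -> kDLUInt, signed integral -> kDLInt, anything else -> -1.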
template <typename T>
constexpr uint8_t GetDLDataTypeCode() {
  if (std::is_same<T, platform::complex<float>>::value ||
      std::is_same<T, platform::complex<double>>::value) {
    return static_cast<uint8_t>(kDLComplex);
  }

  if (std::is_same<T, platform::bfloat16>::value) {
    return static_cast<uint8_t>(kDLBfloat);
  }

  return std::is_same<platform::float16, T>::value ||
                 std::is_floating_point<T>::value
             ? static_cast<uint8_t>(kDLFloat)
             : (std::is_unsigned<T>::value
                    ? static_cast<uint8_t>(kDLUInt)
                    : (std::is_integral<T>::value ? static_cast<uint8_t>(kDLInt)
                                                  : static_cast<uint8_t>(-1)));
}
}  // NOLINT

// Builds a DLPackTensor from a framework::Tensor allocated on `place` and
// checks the exported ::DLTensor: data pointer, device type/id, shape,
// null strides, zero byte_offset, and the expected dtype code, bits and lanes.
template <typename T>
void TestMain(const platform::Place &place, uint16_t lanes) {
  DDim dims{4, 5, 6, 7};
  Tensor tensor;
  tensor.Resize(dims);
  void *p = tensor.mutable_data<T>(place);

  DLPackTensor dlpack_tensor(tensor, lanes);
  ::DLTensor &dl_tensor = dlpack_tensor;

  CHECK_EQ(p, dl_tensor.data);
  if (platform::is_cpu_place(place)) {
    CHECK_EQ(kDLCPU, dl_tensor.device.device_type);
    CHECK_EQ(0, dl_tensor.device.device_id);
  } else if (platform::is_gpu_place(place)) {
    CHECK_EQ(kDLGPU, dl_tensor.device.device_type);
    CHECK_EQ(place.device, dl_tensor.device.device_id);
  } else if (platform::is_cuda_pinned_place(place)) {
    CHECK_EQ(kDLCPUPinned, dl_tensor.device.device_type);
    CHECK_EQ(0, dl_tensor.device.device_id);
  } else {
    CHECK_EQ(false, true);
  }

  CHECK_EQ(dims.size(), dl_tensor.ndim);
  for (auto i = 0; i < dims.size(); ++i) {
    CHECK_EQ(dims[i], dl_tensor.shape[i]);
  }

  CHECK_EQ(dl_tensor.strides == nullptr, true);
  CHECK_EQ(static_cast<uint64_t>(0), dl_tensor.byte_offset);

  CHECK_EQ(lanes, dl_tensor.dtype.lanes);
  CHECK_EQ(sizeof(T) * 8, dl_tensor.dtype.bits);

  CHECK_EQ(GetDLDataTypeCode<T>(), dl_tensor.dtype.code);
}

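// Converts a DLPackTensor into a DLManagedTensor and checks manager_ctx,
// the shape, and the row-major strides filled in for a {6, 7} tensor, then
// calls the deleter to release the managed tensor.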
template <typename T>
void TestToDLManagedTensor(const platform::Place &place, uint16_t lanes) {
  DDim dims{6, 7};
  Tensor tensor;
  tensor.Resize(dims);
  tensor.mutable_data<T>(place);

  DLPackTensor dlpack_tensor(tensor, lanes);

  ::DLManagedTensor *dl_managed_tensor = dlpack_tensor.ToDLManagedTensor();

  CHECK_EQ(dl_managed_tensor->manager_ctx == nullptr, true);

  for (auto i = 0; i < dims.size(); ++i) {
    CHECK_EQ(dims[i], dl_managed_tensor->dl_tensor.shape[i]);
  }

  CHECK_EQ(dl_managed_tensor->dl_tensor.strides[0] == 7, true);
  CHECK_EQ(dl_managed_tensor->dl_tensor.strides[1] == 1, true);

  dl_managed_tensor->deleter(dl_managed_tensor);
}

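// Runs TestMain and TestToDLManagedTensor on every available place
// (CPU, plus GPU and CUDA-pinned when built with CUDA or HIP) with lane
// counts of 1 and 2.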
template <typename T>
void TestMainLoop() {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  std::vector<platform::Place> places{platform::CPUPlace(),
                                      platform::CUDAPlace(0),
                                      platform::CUDAPinnedPlace()};
  if (platform::GetGPUDeviceCount() > 1) {
    places.emplace_back(platform::CUDAPlace(1));
  }
#else
  std::vector<platform::Place> places{platform::CPUPlace()};
#endif
  std::vector<uint16_t> lanes{1, 2};
  for (auto &p : places) {
    for (auto &l : lanes) {
      TestMain<T>(p, l);
      TestToDLManagedTensor<T>(p, l);
    }
  }
}
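
// Instantiates the checks above for every data type registered via
// _ForEachDataType_.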
TEST(dlpack, test_all) {
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()

  _ForEachDataType_(TestCallback);
}

}  // namespace framework
}  // namespace paddle