//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/framework/tensor.h"

#include <cstdint>
#include <string>

#include <gtest/gtest.h>

namespace framework = paddle::framework;
namespace platform = paddle::platform;

// Resizing a default-constructed tensor must report exactly the requested
// dimensions ({2, 3, 4}) through dims().
TEST(Tensor, Dims) {
  framework::Tensor tt;
  tt.Resize({2, 3, 4});
  framework::DDim dims = tt.dims();
  ASSERT_EQ(arity(dims), 3);
  // The requested extents are 2, 3, 4, i.e. i + 2 for dimension i.
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(i + 2, dims[i]);
  }
}

// Reading data() from a tensor that has never been allocated must throw
// platform::EnforceNotMet whose message starts with the "holds no memory"
// hint pointing the caller at Tensor::mutable_data.
TEST(Tensor, DataAssert) {
  framework::Tensor src_tensor;

  bool caught = false;
  try {
    src_tensor.data<double>();
  } catch (const platform::EnforceNotMet& err) {
    caught = true;
    const std::string msg =
        "holder_ should not be null\nTensor holds no memory. Call "
        "Tensor::mutable_data first.";
    // Only the prefix of what() is checked; anything after it is ignored.
    // substr() clamps to the string length, so a shorter what() simply
    // fails the comparison instead of reading past its end.
    const std::string what(err.what());
    ASSERT_EQ(msg, what.substr(0, msg.size()));
  }
  ASSERT_TRUE(caught);
}

TEST(Tensor, MutableData) {
52
  {
53
    framework::Tensor src_tensor;
54 55 56
    float* p1 = nullptr;
    float* p2 = nullptr;
    // initialization
57 58
    p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
                                        platform::CPUPlace());
59
    EXPECT_NE(p1, nullptr);
F
fengjiayi 已提交
60
    // set src_tensor a new dim with large size
61
    // momery is supposed to be re-allocated
62 63
    p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
                                        platform::CPUPlace());
64 65
    EXPECT_NE(p2, nullptr);
    EXPECT_NE(p1, p2);
F
fengjiayi 已提交
66
    // set src_tensor a new dim with same size
67
    // momery block is supposed to be unchanged
68 69
    p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
                                        platform::CPUPlace());
70
    EXPECT_EQ(p1, p2);
F
fengjiayi 已提交
71
    // set src_tensor a new dim with smaller size
72
    // momery block is supposed to be unchanged
73 74
    p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
                                        platform::CPUPlace());
75 76
    EXPECT_EQ(p1, p2);
  }
L
liaogang 已提交
77

78
#ifdef PADDLE_WITH_CUDA
79
  {
80
    framework::Tensor src_tensor;
81 82 83
    float* p1 = nullptr;
    float* p2 = nullptr;
    // initialization
84 85
    p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
                                        platform::CUDAPlace());
86 87 88
    EXPECT_NE(p1, nullptr);
    // set src_tensor a new dim with large size
    // momery is supposed to be re-allocated
89 90
    p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
                                        platform::CUDAPlace());
91 92 93 94
    EXPECT_NE(p2, nullptr);
    EXPECT_NE(p1, p2);
    // set src_tensor a new dim with same size
    // momery block is supposed to be unchanged
95 96
    p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
                                        platform::CUDAPlace());
97 98 99
    EXPECT_EQ(p1, p2);
    // set src_tensor a new dim with smaller size
    // momery block is supposed to be unchanged
100 101
    p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
                                        platform::CUDAPlace());
102 103 104
    EXPECT_EQ(p1, p2);
  }
#endif
F
fengjiayi 已提交
105
}

// ShareDataWith() must refuse an uninitialized source tensor (throwing
// platform::EnforceNotMet with the "holds no memory" hint). Once the source
// owns memory, it must make the destination expose the very same data
// pointer -- on CPU and, when built with CUDA, on the GPU.
TEST(Tensor, ShareDataWith) {
  {
    framework::Tensor src_tensor;
    framework::Tensor dst_tensor;
    // Try to share data from an uninitialized tensor.
    bool caught = false;
    try {
      dst_tensor.ShareDataWith(src_tensor);
    } catch (const platform::EnforceNotMet& err) {
      caught = true;
      const std::string msg =
          "holder_ should not be null\nTensor holds no memory. Call "
          "Tensor::mutable_data first.";
      // Only the prefix of what() is checked; anything after it is ignored.
      // substr() clamps to the string length, so a shorter what() simply
      // fails the comparison instead of reading past its end.
      const std::string what(err.what());
      ASSERT_EQ(msg, what.substr(0, msg.size()));
    }
    ASSERT_TRUE(caught);

    src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
                                 platform::CPUPlace());
    dst_tensor.ShareDataWith(src_tensor);
    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
  }

#ifdef PADDLE_WITH_CUDA
  {
    framework::Tensor src_tensor;
    framework::Tensor dst_tensor;
    src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
                                 platform::CUDAPlace());
    dst_tensor.ShareDataWith(src_tensor);
    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
  }
#endif
}

// Slice(begin, end) along the first dimension must yield a tensor whose
// leading extent is end - begin and whose other extents are unchanged.
// The slice's data pointer must point into the source buffer at the offset
// of row `begin`. Calling mutable_data() with a tensor's current dims must
// return the pointer it already had, for both the source and the slice.
TEST(Tensor, Slice) {
  {
    framework::Tensor src_tensor;
    src_tensor.mutable_data<int>(framework::make_ddim({5, 3, 4}),
                                 platform::CPUPlace());
    framework::Tensor slice_tensor = src_tensor.Slice(1, 3);
    framework::DDim slice_dims = slice_tensor.dims();
    ASSERT_EQ(arity(slice_dims), 3);
    EXPECT_EQ(slice_dims[0], 2);
    EXPECT_EQ(slice_dims[1], 3);
    EXPECT_EQ(slice_dims[2], 4);

    const std::uintptr_t src_data_address =
        reinterpret_cast<std::uintptr_t>(src_tensor.data<int>());
    const std::uintptr_t src_mutable_data_address =
        reinterpret_cast<std::uintptr_t>(src_tensor.mutable_data<int>(
            src_tensor.dims(), platform::CPUPlace()));
    const std::uintptr_t slice_data_address =
        reinterpret_cast<std::uintptr_t>(slice_tensor.data<int>());
    const std::uintptr_t slice_mutable_data_address =
        reinterpret_cast<std::uintptr_t>(slice_tensor.mutable_data<int>(
            slice_tensor.dims(), platform::CPUPlace()));
    EXPECT_EQ(src_data_address, src_mutable_data_address);
    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
    // The slice starts at row 1; each row of the source holds 3 * 4 ints.
    EXPECT_EQ(src_data_address + 1 * 3 * 4 * sizeof(int), slice_data_address);
  }

#ifdef PADDLE_WITH_CUDA
  {
    framework::Tensor src_tensor;
    src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
                                    platform::CUDAPlace());
    framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
    framework::DDim slice_dims = slice_tensor.dims();
    ASSERT_EQ(arity(slice_dims), 2);
    EXPECT_EQ(slice_dims[0], 4);
    EXPECT_EQ(slice_dims[1], 9);

    const std::uintptr_t src_data_address =
        reinterpret_cast<std::uintptr_t>(src_tensor.data<double>());
    const std::uintptr_t src_mutable_data_address =
        reinterpret_cast<std::uintptr_t>(src_tensor.mutable_data<double>(
            src_tensor.dims(), platform::CUDAPlace()));
    const std::uintptr_t slice_data_address =
        reinterpret_cast<std::uintptr_t>(slice_tensor.data<double>());
    const std::uintptr_t slice_mutable_data_address =
        reinterpret_cast<std::uintptr_t>(slice_tensor.mutable_data<double>(
            slice_tensor.dims(), platform::CUDAPlace()));
    EXPECT_EQ(src_data_address, src_mutable_data_address);
    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
    // The slice starts at row 2; each row of the source holds 9 doubles.
    EXPECT_EQ(src_data_address + 2 * 9 * sizeof(double), slice_data_address);
  }
#endif
}

// ReshapeToMatrix(src, 2) on a {2, 3, 4, 9} tensor must fold the first two
// dimensions into rows and the remaining two into columns, giving a
// (2 * 3) x (4 * 9) matrix.
TEST(Tensor, ReshapeToMatrix) {
  framework::Tensor src;
  int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, platform::CPUPlace());
  for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
    src_ptr[i] = i;
  }
  framework::Tensor res = framework::ReshapeToMatrix(src, 2);
  // Check the rank before indexing dims()[1].
  ASSERT_EQ(arity(res.dims()), 2);
  ASSERT_EQ(res.dims()[0], 2 * 3);
  ASSERT_EQ(res.dims()[1], 4 * 9);
  // NOTE(review): the element values written above are not verified here;
  // consider checking res.data<int>() once the intended aliasing/copy
  // semantics of ReshapeToMatrix are confirmed.
}

// A freshly constructed tensor reports kNHWC as its layout, and set_layout()
// replaces it with the given value.
TEST(Tensor, Layout) {
  framework::Tensor src;
  ASSERT_EQ(src.layout(), framework::DataLayout::kNHWC);
  src.set_layout(framework::DataLayout::kAnyLayout);
  ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
}