//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstring>
#include <set>

#include "gtest/gtest.h"

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {
namespace tests {

template <typename T>
L
Leo Chen 已提交
26 27 28
inline phi::funcs::BlasT<phi::CPUContext, T> GetBlas(
    const phi::CPUContext& context) {
  return phi::funcs::GetBlas<phi::CPUContext, T>(context);
Y
Yu Yang 已提交
29 30
}

G
guosheng 已提交
31
// Checks Blas::GEMM (no transpose) against hand-computed values.
// The call works on sub-blocks: by offsetting the raw pointers and passing
// the full matrices' leading dimensions, it computes
//   input3[:, 1:4] += input1 * input2[:, 1:4]
// so column 0 of input3 must remain untouched.
TEST(math_function, gemm_notrans_cblas) {
  phi::DenseTensor input1;
  phi::DenseTensor input2;
  phi::DenseTensor input3;

  int m = 2;
  int n = 3;
  int k = 3;
  auto* dev_ctx =
      phi::DeviceContextPool::Instance().GetByPlace(phi::CPUPlace());

  input1.Resize({2, 3});
  float* input1_ptr = dev_ctx->template Alloc<float>(&input1);
  float arr1[6] = {0, 1, 2, 3, 4, 5};
  memcpy(input1_ptr, arr1, 6 * sizeof(float));
  input2.Resize({3, 4});
  float* input2_ptr = dev_ctx->template Alloc<float>(&input2);
  float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  memcpy(input2_ptr, arr2, 12 * sizeof(float));
  input3.Resize({2, 4});
  float* input3_ptr = dev_ctx->template Alloc<float>(&input3);
  float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  memcpy(input3_ptr, arr3, 8 * sizeof(float));

  // C (2x3 block of input3, ldc=4) = 1 * A (2x3, lda=3)
  //                                  * B (3x3 block of input2, ldb=4) + 1 * C.
  GetBlas<float>(*dev_ctx).GEMM(false,
                                false,
                                m,
                                n,
                                k,
                                1,
                                input1_ptr,
                                3,
                                input2_ptr + 1,
                                4,
                                1,
                                input3_ptr + 1,
                                4);

  // Indices 0 and 4 (column 0) keep their initial values.
  EXPECT_EQ(input3_ptr[0], 0);
  EXPECT_EQ(input3_ptr[1], 24);
  EXPECT_EQ(input3_ptr[2], 28);
  EXPECT_EQ(input3_ptr[3], 32);
  EXPECT_EQ(input3_ptr[4], 4);
  EXPECT_EQ(input3_ptr[5], 73);
  EXPECT_EQ(input3_ptr[6], 86);
  EXPECT_EQ(input3_ptr[7], 99);
}
#ifdef PADDLE_WITH_LIBXSMM
// Cross-checks libxsmm's SMM_GEMM against the CBLAS GEMM for a row-major
// (m x k) * (k x n) product with alpha = 1 and beta = 0; results must agree
// element-wise (EXPECT_FLOAT_EQ).
template <typename T>
void MklSmmCompare(int m, int n, int k) {
  phi::DenseTensor mat_a;
  phi::DenseTensor mat_b;
  phi::DenseTensor mat_c_smm;
  phi::DenseTensor mat_c_mkl;

  auto* dev_ctx =
      phi::DeviceContextPool::Instance().GetByPlace(phi::CPUPlace());

  mat_a.Resize({m, k});
  T* A = dev_ctx->template Alloc<T>(&mat_a);
  mat_b.Resize({k, n});
  T* B = dev_ctx->template Alloc<T>(&mat_b);
  mat_c_smm.Resize({m, n});
  T* CSMM = dev_ctx->template Alloc<T>(&mat_c_smm);
  mat_c_mkl.Resize({m, n});
  T* CMKL = dev_ctx->template Alloc<T>(&mat_c_mkl);
  T alpha = static_cast<T>(1);
  T beta = static_cast<T>(0);
  for (int i = 0; i < mat_a.numel(); ++i) {
    A[i] = static_cast<T>(i);
  }
  for (int i = 0; i < mat_b.numel(); ++i) {
    B[i] = static_cast<T>(i);
  }
  // lda,ldb,ldc follow RowMajor
  int lda = k;
  int ldb = n;
  int ldc = n;

  auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
    // SMM_GEMM uses a column-major (Fortran-style) interface, so the
    // row-major product A * B is computed as B^T * A^T with swapped
    // operands and dimensions.
    const char transa = 'N';
    const char transb = 'N';
    phi::funcs::CBlas<T>::SMM_GEMM(&transa,
                                   &transb,
                                   &n,
                                   &m,
                                   &k,
                                   &alpha,
                                   B,
                                   &ldb,
                                   A,
                                   &lda,
                                   &beta,
                                   CSMM,
                                   &ldc);
  };

  auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
    phi::funcs::CBlas<T>::GEMM(CblasRowMajor,
                               CblasNoTrans,
                               CblasNoTrans,
                               m,
                               n,
                               k,
                               alpha,
                               A,
                               lda,
                               B,
                               ldb,
                               beta,
                               CMKL,
                               ldc);
  };

  smm();
  mkl();
  ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel());
  for (int i = 0; i < mat_c_mkl.numel(); ++i) {
    EXPECT_FLOAT_EQ(CSMM[i], CMKL[i]);
  }
}
TEST(math_function, gemm_mkl_vs_smm) {
  MklSmmCompare<float>(1, 2, 3);
  MklSmmCompare<double>(1, 2, 3);
  MklSmmCompare<float>(3, 2, 1);
  MklSmmCompare<double>(3, 2, 1);
  MklSmmCompare<float>(3, 8, 5);
  MklSmmCompare<double>(3, 8, 5);
}
#endif
// Checks Blas::GEMM with the B operand transposed. input2 here holds the
// transpose of the matrix block used in gemm_notrans_cblas, so the expected
// output values are identical to that test.
TEST(math_function, gemm_trans_cblas) {
  phi::DenseTensor input1;
  phi::DenseTensor input2;
  phi::DenseTensor input3;

  int m = 2;
  int n = 3;
  int k = 3;
  auto* dev_ctx =
      phi::DeviceContextPool::Instance().GetByPlace(phi::CPUPlace());

  input1.Resize({2, 3});
  float* input1_ptr = dev_ctx->template Alloc<float>(&input1);
  float arr1[6] = {0, 1, 2, 3, 4, 5};
  memcpy(input1_ptr, arr1, 6 * sizeof(float));
  input2.Resize({4, 3});
  float* input2_ptr = dev_ctx->template Alloc<float>(&input2);
  float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11};
  memcpy(input2_ptr, arr2, 12 * sizeof(float));
  input3.Resize({2, 4});
  float* input3_ptr = dev_ctx->template Alloc<float>(&input3);
  float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  memcpy(input3_ptr, arr3, 8 * sizeof(float));

  // C (2x3 block of input3, ldc=4) = 1 * A (2x3, lda=3)
  //                                  * B^T (B = rows 1..3 of input2, ldb=3)
  //                                  + 1 * C.
  GetBlas<float>(*dev_ctx).GEMM(false,
                                true,
                                m,
                                n,
                                k,
                                1,
                                input1_ptr,
                                3,
                                input2_ptr + 3,
                                3,
                                1,
                                input3_ptr + 1,
                                4);

  // Indices 0 and 4 (column 0) keep their initial values.
  EXPECT_EQ(input3_ptr[0], 0);
  EXPECT_EQ(input3_ptr[1], 24);
  EXPECT_EQ(input3_ptr[2], 28);
  EXPECT_EQ(input3_ptr[3], 32);
  EXPECT_EQ(input3_ptr[4], 4);
  EXPECT_EQ(input3_ptr[5], 73);
  EXPECT_EQ(input3_ptr[6], 86);
  EXPECT_EQ(input3_ptr[7], 99);
}
// Verifies SetConstant fills every element of a CPU tensor, for two
// different constants in sequence.
TEST(math_function, zero) {
  phi::DenseTensor tensor;
  auto* dev_ctx =
      phi::DeviceContextPool::Instance().GetByPlace(phi::CPUPlace());

  tensor.Resize({2, 2});
  float* t = dev_ctx->template Alloc<float>(&tensor);
  phi::funcs::SetConstant<phi::CPUContext, float> functor;
  functor(*dev_ctx, &tensor, 0);
  EXPECT_EQ(t[0], 0);
  EXPECT_EQ(t[1], 0);
  EXPECT_EQ(t[2], 0);
  EXPECT_EQ(t[3], 0);

  // Overwriting with a different constant must reach all elements too.
  functor(*dev_ctx, &tensor, 1);

  EXPECT_EQ(t[0], 1);
  EXPECT_EQ(t[1], 1);
  EXPECT_EQ(t[2], 1);
  EXPECT_EQ(t[3], 1);
}
template <typename T>
void GemvTest(int m, int n, bool trans) {
234 235 236
  phi::DenseTensor mat_a;
  phi::DenseTensor vec_b;
  phi::DenseTensor vec_c;
237 238 239
  int b_num = trans ? m : n;
  int c_num = trans ? n : m;

240 241 242 243 244 245 246 247 248
  auto* dev_ctx =
      phi::DeviceContextPool::Instance().GetByPlace(phi::CPUPlace());

  mat_a.Resize({m, n});
  T* data_a = dev_ctx->template Alloc<T>(&mat_a);
  vec_b.Resize({b_num});
  T* data_b = dev_ctx->template Alloc<T>(&vec_b);
  vec_c.Resize({c_num});
  T* data_c = dev_ctx->template Alloc<T>(&vec_c);
249 250 251 252 253 254 255
  for (int i = 0; i < mat_a.numel(); ++i) {
    data_a[i] = static_cast<T>(i);
  }
  for (int i = 0; i < vec_b.numel(); ++i) {
    data_b[i] = static_cast<T>(i);
  }

256 257 258 259 260 261 262 263
  GetBlas<T>(*dev_ctx).GEMV(trans,
                            static_cast<int>(m),
                            static_cast<int>(n),
                            1.,
                            data_a,
                            data_b,
                            0.,
                            data_c);
264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289

  if (!trans) {
    for (int i = 0; i < m; ++i) {
      T sum = 0.0;
      for (int j = 0; j < n; ++j) {
        sum += data_a[i * n + j] * data_b[j];
      }
      ASSERT_FLOAT_EQ(data_c[i], sum);
    }
  } else {
    for (int i = 0; i < n; ++i) {
      T sum = 0.0;
      for (int j = 0; j < m; ++j) {
        sum += data_a[j * n + i] * data_b[j];
      }
      ASSERT_FLOAT_EQ(data_c[i], sum);
    }
  }
}

// Drives GemvTest over both float/double and both transpose modes with
// non-square shapes so m and n mix-ups would be caught.
TEST(math_function, gemv) {
  GemvTest<float>(3, 13, false);
  GemvTest<double>(4, 5, false);
  GemvTest<float>(12, 7, true);
  GemvTest<double>(7, 9, true);
}
// Verifies the free function phi::funcs::set_constant on an int tensor.
// NOTE(review): the suite name "math_funciton" is a long-standing typo for
// "math_function"; kept as-is so existing gtest filters keep matching.
TEST(math_funciton, set_constant) {
  phi::DenseTensor t;
  auto* dev_ctx =
      phi::DeviceContextPool::Instance().GetByPlace(phi::CPUPlace());
  t.Resize({10, 10});
  dev_ctx->template Alloc<int>(&t);
  phi::funcs::set_constant(*dev_ctx, &t, 10);
  for (int64_t i = 0; i < t.numel(); ++i) {
    PADDLE_ENFORCE_EQ(
        10,
        t.data<int>()[i],
        phi::errors::InvalidArgument("Each value of input tensor should be 10, "
                                     "but received %d.",
                                     t.data<int>()[i]));
  }
}
template <typename T>
void GemmWarpTest(int m, int n, int k, T alpha, T beta) {
310 311 312 313
  phi::DenseTensor mat_a;
  phi::DenseTensor mat_b;
  phi::DenseTensor mat_c_ref;
  phi::DenseTensor mat_c_mkl;
T
tensor-tang 已提交
314

315 316 317 318 319 320 321 322 323 324 325
  auto* dev_ctx =
      phi::DeviceContextPool::Instance().GetByPlace(phi::CPUPlace());

  mat_a.Resize({m, k});
  T* A = dev_ctx->template Alloc<T>(&mat_a);
  mat_b.Resize({k, n});
  T* B = dev_ctx->template Alloc<T>(&mat_b);
  mat_c_ref.Resize({m, n});
  T* CREF = dev_ctx->template Alloc<T>(&mat_c_ref);
  mat_c_mkl.Resize({m, n});
  T* CMKL = dev_ctx->template Alloc<T>(&mat_c_mkl);
T
tensor-tang 已提交
326 327 328 329 330 331 332 333 334 335 336 337 338 339

  ASSERT_EQ(mat_c_mkl.numel(), mat_c_ref.numel());
  for (int i = 0; i < mat_a.numel(); ++i) {
    A[i] = static_cast<T>(i);
  }
  for (int i = 0; i < mat_b.numel(); ++i) {
    B[i] = static_cast<T>(i + 1);
  }
  for (int i = 0; i < mat_c_ref.numel(); ++i) {
    CREF[i] = static_cast<T>(i + 2);
    CMKL[i] = CREF[i];
  }

  // this would call gemm_warp
340
  GetBlas<T>(*dev_ctx).GEMM(
341
      CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, B, beta, CREF);
T
tensor-tang 已提交
342 343 344 345 346

  // lda,ldb,ldc follow RowMajor
  int lda = k;
  int ldb = n;
  int ldc = n;
347 348 349 350 351 352 353 354 355 356 357 358 359 360
  phi::funcs::CBlas<T>::GEMM(CblasRowMajor,
                             CblasNoTrans,
                             CblasNoTrans,
                             m,
                             n,
                             k,
                             alpha,
                             A,
                             lda,
                             B,
                             ldb,
                             beta,
                             CMKL,
                             ldc);
T
tensor-tang 已提交
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376

  for (int i = 0; i < mat_c_mkl.numel(); ++i) {
    EXPECT_FLOAT_EQ(CREF[i], CMKL[i]);
  }
}

// Drives GemmWarpTest for float and double, covering both the pure-product
// case (alpha=1, beta=0) and the accumulate case (alpha=2, beta=1).
TEST(math_function, gemm_warp) {
  GemmWarpTest<float>(3, 2, 5, 1.f, 0.f);
  GemmWarpTest<float>(3, 2, 5, 2.f, 1.f);
  GemmWarpTest<float>(8, 5, 6, 1.f, 0.f);
  GemmWarpTest<float>(8, 5, 6, 2.f, 1.f);
  GemmWarpTest<double>(3, 2, 5, 1.0, 0.0);
  GemmWarpTest<double>(3, 2, 5, 2.0, 1.0);
  GemmWarpTest<double>(8, 5, 6, 1.0, 0.0);
  GemmWarpTest<double>(8, 5, 6, 2.0, 1.0);
}

}  // namespace tests
}  // namespace phi