// test.cc — unit tests for the JIT kernels (committed by tensor-tang).
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include <algorithm>  // for std::copy
#include <cstring>  // for memcpy
#include <random>
#include <string>
#include <type_traits>  // for std::is_floating_point
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"

/**
 * Fill a[0..n) with pseudo-random values.
 *
 * Each element is computed as `u * (upper - lower) + lower`, where u is drawn
 * uniformly from [0, 1) in double precision and the result is cast to T.
 * Every call seeds a fresh std::mt19937 from a function-local counter that
 * starts at 100 and increments per call (one counter per instantiation of T).
 * Successive calls therefore produce different sequences, and a whole run is
 * reproducible. Nothing is written when n <= 0.
 */
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
               const T upper = static_cast<T>(20.f)) {
  static unsigned int next_seed = 100;
  std::mt19937 engine(next_seed++);
  std::uniform_real_distribution<double> unit(0, 1);
  T* out = a;
  int remaining = n;
  while (remaining-- > 0) {
    const double u = unit(engine);
    *out++ = static_cast<T>(u * (upper - lower) + lower);
  }
}

/**
 * Compare target[0..n) against refer[0..n) element by element using
 * non-fatal gtest expectations.
 *
 * Floating-point element types are checked with an absolute tolerance of
 * 1e-3 (EXPECT_NEAR). All other types must match exactly (EXPECT_EQ).
 */
template <typename T>
void ExpectEQ(const T* target, const T* refer, int n) {
  const bool approximate = std::is_floating_point<T>::value;
  for (int idx = 0; idx < n; ++idx) {
    if (approximate) {
      EXPECT_NEAR(target[idx], refer[idx], 1e-3);
    } else {
      EXPECT_EQ(target[idx], refer[idx]);
    }
  }
}

T
tensor-tang 已提交
49 50
std::vector<int> TestSizes() {
  std::vector<int> s;
T
tensor-tang 已提交
51
  for (int i = 1; i < 32; ++i) {
T
tensor-tang 已提交
52 53
    s.push_back(i);
  }
T
tensor-tang 已提交
54 55 56 57
  // test some large size
  s.push_back(100);
  s.push_back(1000);
  s.push_back(2000);
T
tensor-tang 已提交
58 59 60
  return s;
}

61 62 63 64
template <typename T, typename KernelTuples>
void TestTartgetFunc(const typename KernelTuples::func_type tgt,
                     const std::vector<T>& x, const std::vector<T>& y,
                     const std::vector<T>& zref) {
T
tensor-tang 已提交
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
  EXPECT_TRUE(tgt != nullptr);
  EXPECT_EQ(zref.size(), x.size());
  EXPECT_EQ(zref.size(), y.size());
  const T* x_data = x.data();
  const T* y_data = y.data();
  const T* zref_data = zref.data();
  const int d = zref.size();

  std::vector<T> ztgt(d);
  T* ztgt_data = ztgt.data();
  // test normal
  tgt(x_data, y_data, ztgt_data, d);
  ExpectEQ<T>(ztgt_data, zref_data, d);
  // test inplace x
  std::copy(x.begin(), x.end(), ztgt.begin());
  tgt(ztgt_data, y_data, ztgt_data, d);
  ExpectEQ<T>(ztgt_data, zref_data, d);
  // test inplace y
  std::copy(y.begin(), y.end(), ztgt.begin());
  tgt(x_data, ztgt_data, ztgt_data, d);
  ExpectEQ<T>(ztgt_data, zref_data, d);
}

88 89
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestXYZNKernel() {
T
tensor-tang 已提交
90
  namespace jit = paddle::operators::jit;
T
tensor-tang 已提交
91
  for (int d : TestSizes()) {
92 93 94
    VLOG(10) << "===== Test JITKernel " << jit::to_string(KT)
             << ", size: " << d;
    auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
T
tensor-tang 已提交
95 96
    EXPECT_TRUE(ref != nullptr);

T
tensor-tang 已提交
97
    std::vector<T> x(d), y(d), zref(d);
T
tensor-tang 已提交
98 99 100
    RandomVec<T>(d, x.data());
    RandomVec<T>(d, y.data());

T
tensor-tang 已提交
101 102 103 104 105 106 107 108 109 110 111
    std::vector<T> xinp(d), yinp(d);  // inplace test
    std::copy(x.begin(), x.end(), xinp.begin());
    std::copy(y.begin(), y.end(), yinp.begin());

    const T* x_data = x.data();
    const T* y_data = y.data();
    T* zref_data = zref.data();
    T* xinp_data = xinp.data();
    T* yinp_data = yinp.data();

    // test refer code inplace
T
tensor-tang 已提交
112
    ref(x_data, y_data, zref_data, d);
T
tensor-tang 已提交
113 114 115 116 117 118
    ref(x_data, yinp_data, yinp_data, d);
    ref(xinp_data, y_data, xinp_data, d);
    ExpectEQ<T>(xinp_data, zref_data, d);
    ExpectEQ<T>(yinp_data, zref_data, d);

    // test jitcode
119
    auto jitcode = jit::GetJitCode<KT, jit::XYZNTuples<T>, PlaceType>(d);
T
tensor-tang 已提交
120
    if (jitcode) {
121 122
      VLOG(10) << "Test Jitcode Kernel, size: " << d;
      TestTartgetFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, zref);
T
tensor-tang 已提交
123 124 125 126 127 128 129 130 131
    }

    // test all impls in more
    jit::KernelKey kkey(KT, PlaceType());
    auto& pool = jit::KernelPool().Instance().AllKernels();
    auto iter = pool.find(kkey);
    if (iter != pool.end()) {
      auto& impls = iter->second;
      for (auto& impl : impls) {
132
        auto i = dynamic_cast<const jit::KernelImpl<jit::XYZNTuples<T>>*>(
T
tensor-tang 已提交
133
            impl.get());
T
tensor-tang 已提交
134 135 136
        if (i && i->UseMe(d)) {
          auto more = i->GetFunc();
          VLOG(10) << "Test More Kernel, size: " << d;
137
          TestTartgetFunc<T, jit::XYZNTuples<T>>(more, x, y, zref);
T
tensor-tang 已提交
138 139 140 141 142
        }
      }
    }
    // Test result from Get function
    VLOG(10) << "Test Get function, size: " << d;
143 144
    auto tgt = jit::Get<KT, jit::XYZNTuples<T>, PlaceType>(d);
    TestTartgetFunc<T, jit::XYZNTuples<T>>(tgt, x, y, zref);
T
tensor-tang 已提交
145 146
  }
}
T
tensor-tang 已提交
147

148 149 150
TEST(JITKernel, vmul) {
  namespace jit = paddle::operators::jit;
  TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>();
T
tensor-tang 已提交
151
  TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
}

// Runs the XYZN kernel harness for jit::vadd with float and double on CPUPlace.
TEST(JITKernel, vadd) {
  using Place = paddle::platform::CPUPlace;
  TestXYZNKernel<paddle::operators::jit::vadd, float, Place>();
  TestXYZNKernel<paddle::operators::jit::vadd, double, Place>();
}

// Runs the XYZN kernel harness for jit::vaddrelu with float and double on
// CPUPlace.
TEST(JITKernel, vaddrelu) {
  using Place = paddle::platform::CPUPlace;
  TestXYZNKernel<paddle::operators::jit::vaddrelu, float, Place>();
  TestXYZNKernel<paddle::operators::jit::vaddrelu, double, Place>();
}

// Runs the XYZN kernel harness for jit::vsub with float and double on CPUPlace.
TEST(JITKernel, vsub) {
  using Place = paddle::platform::CPUPlace;
  TestXYZNKernel<paddle::operators::jit::vsub, float, Place>();
  TestXYZNKernel<paddle::operators::jit::vsub, double, Place>();
}

// Placeholder: this test has an empty body, so it always passes and currently
// verifies nothing.
// TODO(review): add KernelPool coverage here or remove the test.
TEST(JITKernel, pool) {}