/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <random> #include <string> #include <vector> #include "gflags/gflags.h" #include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/place.h" template <typename T> void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f), const T upper = static_cast<T>(20.f)) { static unsigned int seed = 100; std::mt19937 rng(seed++); std::uniform_real_distribution<double> uniform_dist(0, 1); for (int i = 0; i < n; ++i) { a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower); } } template <typename T> void ExpectEQ(const T* target, const T* refer, int n) { if (std::is_floating_point<T>::value) { for (int i = 0; i < n; ++i) { EXPECT_NEAR(target[i], refer[i], 1e-5); } } else { for (int i = 0; i < n; ++i) { EXPECT_EQ(target[i], refer[i]); } } } std::vector<int> TestSizes() { std::vector<int> s; for (int i = 1; i < 32; ++i) { s.push_back(i); } // test some large size s.push_back(100); s.push_back(1000); s.push_back(2000); return s; } namespace jit = paddle::operators::jit; template <typename KernelTuples, typename... Args> struct TestFuncWithRefer { void operator()(const typename KernelTuples::func_type tgt, Args... args) {} }; template <typename T> struct TestFuncWithRefer<jit::XYZNTuples<T>, std::vector<T>, std::vector<T>, std::vector<T>> { void operator()(const typename jit::XYZNTuples<T>::func_type tgt, const std::vector<T>& x, const std::vector<T>& y, const std::vector<T>& zref) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(zref.size(), x.size()); EXPECT_EQ(zref.size(), y.size()); const T* x_data = x.data(); const T* y_data = y.data(); const T* zref_data = zref.data(); const int d = zref.size(); std::vector<T> ztgt(d); T* ztgt_data = ztgt.data(); // test normal tgt(x_data, y_data, ztgt_data, d); ExpectEQ<T>(ztgt_data, zref_data, d); // test inplace x std::copy(x.begin(), x.end(), ztgt.begin()); tgt(ztgt_data, y_data, ztgt_data, d); ExpectEQ<T>(ztgt_data, zref_data, d); // test inplace y std::copy(y.begin(), y.end(), ztgt.begin()); tgt(x_data, ztgt_data, ztgt_data, d); ExpectEQ<T>(ztgt_data, zref_data, d); } }; template <typename T> struct TestFuncWithRefer<jit::AXYNTuples<T>, T, std::vector<T>, std::vector<T>> { void operator()(const typename jit::AXYNTuples<T>::func_type tgt, const T a, const std::vector<T>& x, const std::vector<T>& yref) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(yref.size(), x.size()); const T* x_data = x.data(); const T* yref_data = yref.data(); const int d = yref.size(); std::vector<T> ytgt(d); T* ytgt_data = ytgt.data(); // test normal tgt(&a, x_data, ytgt_data, d); ExpectEQ<T>(ytgt_data, yref_data, d); // test inplace x std::copy(x.begin(), x.end(), ytgt.begin()); tgt(&a, ytgt_data, ytgt_data, d); ExpectEQ<T>(ytgt_data, yref_data, d); } }; template <typename T> struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> { void operator()(const typename jit::XYNTuples<T>::func_type tgt, const std::vector<T>& x, const std::vector<T>& yref) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(yref.size(), x.size()); const T* x_data = x.data(); const T* yref_data = yref.data(); const int d = yref.size(); std::vector<T> ytgt(d); T* ytgt_data = ytgt.data(); // test normal tgt(x_data, ytgt_data, d); ExpectEQ<T>(ytgt_data, yref_data, d); // test inplace x std::copy(x.begin(), x.end(), ytgt.begin()); tgt(ytgt_data, ytgt_data, d); ExpectEQ<T>(ytgt_data, yref_data, d); } }; template <typename T> struct TestFuncWithRefer<jit::LSTMTuples<T>, std::vector<T>, std::vector<T>, std::vector<T>, std::vector<T>, std::vector<T>> { void operator()(const typename jit::LSTMTuples<T>::func_type tgt, const std::vector<T>& xsrc, const std::vector<T>& wp, const std::vector<T>& ct_1, const std::vector<T>& ct_ref, const std::vector<T>& ht_ref, const typename jit::LSTMTuples<T>::attr_type& attr) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(ct_ref.size(), ht_ref.size()); EXPECT_EQ(ct_1.size(), ht_ref.size()); EXPECT_EQ(xsrc.size(), 4 * ht_ref.size()); EXPECT_EQ(wp.size(), 3 * ht_ref.size()); // x could be changed after compute, so copy to save src int d = ht_ref.size(); std::vector<T> x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size()); std::vector<T> checked(2 * d); std::copy(xsrc.begin(), xsrc.end(), x.begin()); const T* ct_1_data = ct_1.data(); const T* wp_data = wp.data(); const T* ct_ref_data = ct_ref.data(); const T* ht_ref_data = ht_ref.data(); T* x_data = x.data(); T* ct_data = ct.data(); T* ht_data = ht.data(); T* checked_data = checked.data(); paddle::operators::jit::lstm_t step; step.gates = x_data; step.ct_1 = ct_1_data; step.ct = ct_data; step.ht = ht_data; if (attr.use_peephole) { step.wp = wp_data; step.checked = checked_data; } tgt(&step, &attr); ExpectEQ<T>(ct_data, ct_ref_data, d); ExpectEQ<T>(ht_data, ht_ref_data, d); } }; template <typename T> struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>, std::vector<T>> { void operator()(const typename jit::GRUTuples<T>::func_type tgt, const std::vector<T>& xsrc, const std::vector<T>& ht_1, const std::vector<T>& ht_ref, const typename jit::GRUTuples<T>::attr_type& attr) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(ht_1.size(), ht_ref.size()); EXPECT_EQ(xsrc.size(), 3 * ht_ref.size()); // x could be changed after compute, so copy to save src int d = ht_ref.size(); std::vector<T> x(xsrc.size()), ht(ht_ref.size()); std::copy(xsrc.begin(), xsrc.end(), x.begin()); const T* ht_1_data = ht_1.data(); const T* ht_ref_data = ht_ref.data(); T* x_data = x.data(); T* ht_data = ht.data(); paddle::operators::jit::gru_t step; step.gates = x_data; step.ht_1 = ht_1_data; step.ht = ht_data; tgt(&step, &attr); ExpectEQ<T>(ht_data, ht_ref_data, d); } }; template <typename T> struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>, std::vector<T>> { void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt, const std::vector<T>& x, const std::vector<T>& yref, const typename jit::SeqPoolTuples<T>::attr_type& attr) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(x.size() % yref.size(), 0); int w = yref.size(); std::vector<T> y(w); const T* x_data = x.data(); const T* yref_data = yref.data(); T* y_data = y.data(); tgt(x_data, y_data, &attr); ExpectEQ<T>(y_data, yref_data, w); } }; template <typename T> struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>, std::vector<T>, int, int, int> { void operator()(const typename jit::MatMulTuples<T>::func_type tgt, const std::vector<T>& a, const std::vector<T>& b, const std::vector<T>& cref, int m, int n, int k) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(a.size(), static_cast<size_t>(m * k)); EXPECT_EQ(b.size(), static_cast<size_t>(k * n)); EXPECT_EQ(cref.size(), static_cast<size_t>(m * n)); std::vector<T> c(cref.size()); const T* a_data = a.data(); const T* b_data = b.data(); const T* cref_data = cref.data(); T* c_data = c.data(); tgt(a_data, b_data, c_data, m, n, k); ExpectEQ<T>(c_data, cref_data, m * n); } }; template <paddle::operators::jit::KernelType KT, typename KernelTuples, typename PlaceType, typename... Args> void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { TestFuncWithRefer<KernelTuples, Args...> test; // test jitcode auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr); if (jitcode) { VLOG(10) << "Test Jitcode Kernel "; test(jitcode, args...); } // test all impls in more jit::KernelKey kkey(KT, PlaceType()); auto& pool = jit::KernelPool().Instance().AllKernels(); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get()); if (i && i->UseMe(attr)) { auto more = i->GetFunc(); VLOG(10) << "Test More Kernel : " << i->ImplType(); test(more, args...); } } } // test result from Get function // VLOG(10) << "Test Get function "; auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr); test(tgt, args...); } template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestXYZNKernel() { namespace jit = paddle::operators::jit; VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); for (int d : TestSizes()) { auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>(); EXPECT_TRUE(ref != nullptr); std::vector<T> x(d), y(d), zref(d); RandomVec<T>(d, x.data()); RandomVec<T>(d, y.data()); std::vector<T> xinp(d), yinp(d); // inplace test std::copy(x.begin(), x.end(), xinp.begin()); std::copy(y.begin(), y.end(), yinp.begin()); const T* x_data = x.data(); const T* y_data = y.data(); T* zref_data = zref.data(); T* xinp_data = xinp.data(); T* yinp_data = yinp.data(); // test refer code inplace ref(x_data, y_data, zref_data, d); ref(x_data, yinp_data, yinp_data, d); ref(xinp_data, y_data, xinp_data, d); ExpectEQ<T>(xinp_data, zref_data, d); ExpectEQ<T>(yinp_data, zref_data, d); TestAllImpls<KT, jit::XYZNTuples<T>, PlaceType, std::vector<T>, std::vector<T>, std::vector<T>>(d, x, y, zref); } } template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestAXYNKernel() { namespace jit = paddle::operators::jit; VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); for (int d : TestSizes()) { auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>(); EXPECT_TRUE(ref != nullptr); const T a = static_cast<T>(3); std::vector<T> x(d), yref(d); std::vector<T> xinp(d); // inplace test RandomVec<T>(d, x.data()); std::copy(x.begin(), x.end(), xinp.begin()); const T* x_data = x.data(); T* yref_data = yref.data(); T* xinp_data = xinp.data(); // test refer code inplace ref(&a, x_data, yref_data, d); ref(&a, xinp_data, xinp_data, d); ExpectEQ<T>(xinp_data, yref_data, d); TestAllImpls<KT, jit::AXYNTuples<T>, PlaceType, T, std::vector<T>, std::vector<T>>(d, a, x, yref); } } template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestXYNKernel() { namespace jit = paddle::operators::jit; VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); for (int d : TestSizes()) { auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>(); EXPECT_TRUE(ref != nullptr); std::vector<T> x(d), yref(d); std::vector<T> xinp(d); // inplace test RandomVec<T>(d, x.data(), -2.f, 2.f); std::copy(x.begin(), x.end(), xinp.begin()); const T* x_data = x.data(); T* yref_data = yref.data(); T* xinp_data = xinp.data(); // test refer code inplace ref(x_data, yref_data, d); ref(xinp_data, xinp_data, d); ExpectEQ<T>(xinp_data, yref_data, d); TestAllImpls<KT, jit::XYNTuples<T>, PlaceType, std::vector<T>, std::vector<T>>(d, x, yref); } } template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestLSTMKernel() { namespace jit = paddle::operators::jit; VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"}; for (int d : TestSizes()) { for (bool use_peephole : {true, false}) { for (auto& act_gate : all_acts) { for (auto& act_cand : all_acts) { for (auto& act_cell : all_acts) { const jit::lstm_attr_t attr( d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand), jit::to_kerneltype(act_cell), use_peephole); auto ref = jit::GetRefer<KT, jit::LSTMTuples<T>>(); EXPECT_TRUE(ref != nullptr); std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d); std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d); RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f); RandomVec<T>(3 * d, wp.data(), -2.f, 2.f); RandomVec<T>(d, ct_1.data(), -2.f, 2.f); // x could be changed after compute, so copy to save src std::vector<T> x(xsrc.size()); std::copy(xsrc.begin(), xsrc.end(), x.begin()); const T* ct_1_data = ct_1.data(); const T* wp_data = wp.data(); T* x_data = x.data(); T* checked_data = checked.data(); T* ct_ref_data = ct_ref.data(); T* ht_ref_data = ht_ref.data(); jit::lstm_t step; step.gates = x_data; step.ct_1 = ct_1_data; step.ct = ct_ref_data; step.ht = ht_ref_data; if (use_peephole) { step.wp = wp_data; step.checked = checked_data; } ref(&step, &attr); VLOG(10) << attr; TestAllImpls<KT, jit::LSTMTuples<T>, PlaceType, std::vector<T>, std::vector<T>, std::vector<T>, std::vector<T>, std::vector<T>>(attr, xsrc, wp, ct_1, ct_ref, ht_ref, attr); } } } } } } template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestGRUKernel() { namespace jit = paddle::operators::jit; VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"}; for (int d : TestSizes()) { for (auto& act_gate : all_acts) { for (auto& act_cand : all_acts) { const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand)); auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>(); EXPECT_TRUE(ref != nullptr); std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d); RandomVec<T>(3 * d, xsrc.data(), -2.f, 2.f); RandomVec<T>(d, ht_1.data(), -2.f, 2.f); // x could be changed after compute, so copy to save src std::vector<T> x(xsrc.size()); std::copy(xsrc.begin(), xsrc.end(), x.begin()); const T* ht_1_data = ht_1.data(); T* x_data = x.data(); T* ht_ref_data = ht_ref.data(); jit::gru_t step; step.gates = x_data; step.ht_1 = ht_1_data; step.ht = ht_ref_data; ref(&step, &attr); VLOG(10) << attr; TestAllImpls<KT, jit::GRUTuples<T>, PlaceType, std::vector<T>, std::vector<T>, std::vector<T>>(attr, xsrc, ht_1, ht_ref, attr); } } } } template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestSeqPoolKernel() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); std::vector<jit::SeqPoolType> pool_types = { jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; for (auto type : pool_types) { for (int w : TestSizes()) { jit::seq_pool_attr_t attr(w, type); for (int h : TestSizes()) { attr.h = h; auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>(); EXPECT_TRUE(ref != nullptr); std::vector<T> x(h * w), yref(w); RandomVec<T>(h * w, x.data(), -2.f, 2.f); const T* x_data = x.data(); T* yref_data = yref.data(); ref(x_data, yref_data, &attr); VLOG(10) << attr; TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>, std::vector<T>>(attr, x, yref, attr); } } } } template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestMatMulKernel() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); for (int m : {1, 2, 3, 4}) { for (int n : {1, 2, 3, 4}) { for (int k : TestSizes()) { auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>(); EXPECT_TRUE(ref != nullptr); std::vector<T> a(m * k), b(k * n), c(m * n); RandomVec<T>(m * k, a.data(), -0.2f, 0.2f); RandomVec<T>(k * n, b.data(), -0.2f, 0.2f); const T* a_data = a.data(); const T* b_data = b.data(); T* c_data = c.data(); ref(a_data, b_data, c_data, m, n, k); TestAllImpls<KT, jit::MatMulTuples<T>, PlaceType, std::vector<T>, std::vector<T>, std::vector<T>>(k, a, b, c, m, n, k); } } } } template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestNCHW16CMulNCKernel() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); const int n = 3, c = 16 * 4, h = 10, w = 10; auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>(); EXPECT_TRUE(ref != nullptr); int sz = n * c * h * w; std::vector<T> x(sz), y(n * c), zref(sz); std::vector<T> ztgt(sz), zjit(sz); RandomVec<T>(sz, x.data(), -2.f, 2.f); RandomVec<T>(n * c, y.data(), -2.f, 2.f); const T* x_data = x.data(); const T* y_data = y.data(); T* zref_data = zref.data(); T* ztgt_data = ztgt.data(); T* zjit_data = zjit.data(); constexpr int simd_width = ZMM_FLOAT_BLOCK; int C = c / simd_width; auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0); auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0); EXPECT_TRUE(tgt != nullptr); if (std::is_same<T, float>::value && paddle::platform::MayIUse(paddle::platform::avx512f)) { EXPECT_TRUE(jitcode != nullptr); } for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { auto ptr_x = x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; auto ptr_zref = zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width; auto ptr_ztgt = ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width; ref(ptr_x, ptr_y, ptr_zref, h, w); tgt(ptr_x, ptr_y, ptr_ztgt, h, w); if (jitcode) { auto ptr_zjit = zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width; jitcode(ptr_x, ptr_y, ptr_zjit, h, w); } } } ExpectEQ<T>(ztgt_data, zref_data, sz); if (jitcode) { ExpectEQ<T>(zjit_data, zref_data, sz); } } // XYZNTuple TEST(JITKernel, kVMul) { namespace jit = paddle::operators::jit; TestXYZNKernel<jit::kVMul, float, paddle::platform::CPUPlace>(); TestXYZNKernel<jit::kVMul, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVAdd) { namespace jit = paddle::operators::jit; TestXYZNKernel<jit::kVAdd, float, paddle::platform::CPUPlace>(); TestXYZNKernel<jit::kVAdd, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVAddRelu) { namespace jit = paddle::operators::jit; TestXYZNKernel<jit::kVAddRelu, float, paddle::platform::CPUPlace>(); TestXYZNKernel<jit::kVAddRelu, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVSub) { namespace jit = paddle::operators::jit; TestXYZNKernel<jit::kVSub, float, paddle::platform::CPUPlace>(); TestXYZNKernel<jit::kVSub, double, paddle::platform::CPUPlace>(); } // AXYNTuples TEST(JITKernel, kVScal) { namespace jit = paddle::operators::jit; TestAXYNKernel<jit::kVScal, float, paddle::platform::CPUPlace>(); TestAXYNKernel<jit::kVScal, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVAddBias) { namespace jit = paddle::operators::jit; TestAXYNKernel<jit::kVAddBias, float, paddle::platform::CPUPlace>(); TestAXYNKernel<jit::kVAddBias, double, paddle::platform::CPUPlace>(); } // XYNTuples TEST(JITKernel, kVRelu) { namespace jit = paddle::operators::jit; TestXYNKernel<jit::kVRelu, float, paddle::platform::CPUPlace>(); TestXYNKernel<jit::kVRelu, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVIdentity) { namespace jit = paddle::operators::jit; TestXYNKernel<jit::kVIdentity, float, paddle::platform::CPUPlace>(); TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVSquare) { namespace jit = paddle::operators::jit; TestXYNKernel<jit::kVSquare, float, paddle::platform::CPUPlace>(); TestXYNKernel<jit::kVSquare, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVExp) { namespace jit = paddle::operators::jit; TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>(); TestXYNKernel<jit::kVExp, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVSigmoid) { namespace jit = paddle::operators::jit; TestXYNKernel<jit::kVSigmoid, float, paddle::platform::CPUPlace>(); TestXYNKernel<jit::kVSigmoid, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kVTanh) { namespace jit = paddle::operators::jit; TestXYNKernel<jit::kVTanh, float, paddle::platform::CPUPlace>(); TestXYNKernel<jit::kVTanh, double, paddle::platform::CPUPlace>(); } // LSTM TEST(JITKernel, kLSTMCtHt) { namespace jit = paddle::operators::jit; TestLSTMKernel<jit::kLSTMCtHt, float, paddle::platform::CPUPlace>(); TestLSTMKernel<jit::kLSTMCtHt, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kLSTMC1H1) { namespace jit = paddle::operators::jit; TestLSTMKernel<jit::kLSTMC1H1, float, paddle::platform::CPUPlace>(); TestLSTMKernel<jit::kLSTMC1H1, double, paddle::platform::CPUPlace>(); } // GRU TEST(JITKernel, kGRUH1) { namespace jit = paddle::operators::jit; TestGRUKernel<jit::kGRUH1, float, paddle::platform::CPUPlace>(); TestGRUKernel<jit::kGRUH1, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kGRUHtPart1) { namespace jit = paddle::operators::jit; TestGRUKernel<jit::kGRUHtPart1, float, paddle::platform::CPUPlace>(); TestGRUKernel<jit::kGRUHtPart1, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kGRUHtPart2) { namespace jit = paddle::operators::jit; TestGRUKernel<jit::kGRUHtPart2, float, paddle::platform::CPUPlace>(); TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kSeqPool) { namespace jit = paddle::operators::jit; TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>(); TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kMatMul) { namespace jit = paddle::operators::jit; TestMatMulKernel<jit::kMatMul, float, paddle::platform::CPUPlace>(); TestMatMulKernel<jit::kMatMul, double, paddle::platform::CPUPlace>(); } TEST(JITKernel, kNCHW16CMulNC) { namespace jit = paddle::operators::jit; TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, paddle::platform::CPUPlace>(); TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double, paddle::platform::CPUPlace>(); } // TODO(yihua/TJ): add crf decoding and layer norm unit tests TEST(JITKernel, pool) { // TODO(TJ): add some test }