diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index d994a11f97d9aa9d9e41d3e7442ceedb1285edcb..62d4cdc19ae05870789ca624a454ba9080d82e3b 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -57,31 +57,188 @@ std::vector TestSizes() { return s; } -template -void TestXYZNFunc(const typename KernelTuples::func_type tgt, +namespace jit = paddle::operators::jit; + +template +struct TestFuncWithRefer { + void operator()(const typename KernelTuples::func_type tgt, Args... args) {} +}; + +template +struct TestFuncWithRefer, std::vector, std::vector, + std::vector> { + void operator()(const typename jit::XYZNTuples::func_type tgt, const std::vector& x, const std::vector& y, const std::vector& zref) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(zref.size(), x.size()); - EXPECT_EQ(zref.size(), y.size()); - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* zref_data = zref.data(); - const int d = zref.size(); - - std::vector ztgt(d); - T* ztgt_data = ztgt.data(); - // test normal - tgt(x_data, y_data, ztgt_data, d); - ExpectEQ(ztgt_data, zref_data, d); - // test inplace x - std::copy(x.begin(), x.end(), ztgt.begin()); - tgt(ztgt_data, y_data, ztgt_data, d); - ExpectEQ(ztgt_data, zref_data, d); - // test inplace y - std::copy(y.begin(), y.end(), ztgt.begin()); - tgt(x_data, ztgt_data, ztgt_data, d); - ExpectEQ(ztgt_data, zref_data, d); + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(zref.size(), x.size()); + EXPECT_EQ(zref.size(), y.size()); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* zref_data = zref.data(); + const int d = zref.size(); + + std::vector ztgt(d); + T* ztgt_data = ztgt.data(); + // test normal + tgt(x_data, y_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ztgt.begin()); + tgt(ztgt_data, y_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + // test inplace y + std::copy(y.begin(), y.end(), ztgt.begin()); + tgt(x_data, ztgt_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + } +}; + +template +struct TestFuncWithRefer, T, std::vector, + std::vector> { + void operator()(const typename jit::AXYNTuples::func_type tgt, const T a, + const std::vector& x, const std::vector& yref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(yref.size(), x.size()); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + const int d = yref.size(); + std::vector ytgt(d); + T* ytgt_data = ytgt.data(); + // test normal + tgt(&a, x_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ytgt.begin()); + tgt(&a, ytgt_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + } +}; + +template +struct TestFuncWithRefer, std::vector, std::vector> { + void operator()(const typename jit::XYNTuples::func_type tgt, + const std::vector& x, const std::vector& yref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(yref.size(), x.size()); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + const int d = yref.size(); + std::vector ytgt(d); + T* ytgt_data = ytgt.data(); + // test normal + tgt(x_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ytgt.begin()); + tgt(ytgt_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + } +}; + +template +struct TestFuncWithRefer, std::vector, std::vector, + std::vector, std::vector, std::vector> { + void operator()(const typename jit::LSTMTuples::func_type tgt, + const std::vector& xsrc, const std::vector& wp, + const std::vector& ct_1, const std::vector& ct_ref, + const std::vector& ht_ref, + const typename jit::LSTMTuples::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(ct_ref.size(), ht_ref.size()); + EXPECT_EQ(ct_1.size(), ht_ref.size()); + EXPECT_EQ(xsrc.size(), 4 * ht_ref.size()); + EXPECT_EQ(wp.size(), 3 * ht_ref.size()); + + // x could be changed after compute, so copy to save src + int d = ht_ref.size(); + std::vector x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size()); + std::vector checked(2 * d); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + + const T* ct_1_data = ct_1.data(); + const T* wp_data = wp.data(); + const T* ct_ref_data = ct_ref.data(); + const T* ht_ref_data = ht_ref.data(); + T* x_data = x.data(); + T* ct_data = ct.data(); + T* ht_data = ht.data(); + T* checked_data = checked.data(); + + paddle::operators::jit::lstm_t step; + step.gates = x_data; + step.ct_1 = ct_1_data; + step.ct = ct_data; + step.ht = ht_data; + if (attr.use_peephole) { + step.wp = wp_data; + step.checked = checked_data; + } + + tgt(&step, &attr); + ExpectEQ(ct_data, ct_ref_data, d); + ExpectEQ(ht_data, ht_ref_data, d); + } +}; + +template +struct TestFuncWithRefer, std::vector, std::vector, + std::vector> { + void operator()(const typename jit::GRUTuples::func_type tgt, + const std::vector& xsrc, const std::vector& ht_1, + const std::vector& ht_ref, + const typename jit::GRUTuples::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(ht_1.size(), ht_ref.size()); + EXPECT_EQ(xsrc.size(), 3 * ht_ref.size()); + + // x could be changed after compute, so copy to save src + int d = ht_ref.size(); + std::vector x(xsrc.size()), ht(ht_ref.size()); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + const T* ht_1_data = ht_1.data(); + const T* ht_ref_data = ht_ref.data(); + T* x_data = x.data(); + T* ht_data = ht.data(); + paddle::operators::jit::gru_t step; + step.gates = x_data; + step.ht_1 = ht_1_data; + step.ht = ht_data; + tgt(&step, &attr); + ExpectEQ(ht_data, ht_ref_data, d); + } +}; + +template +void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { + TestFuncWithRefer test; + // test jitcode + auto jitcode = jit::GetJitCode(attr); + if (jitcode) { + VLOG(10) << "Test Jitcode Kernel "; + test(jitcode, args...); + } + // test all impls in more + jit::KernelKey kkey(KT, PlaceType()); + auto& pool = jit::KernelPool().Instance().AllKernels(); + auto iter = pool.find(kkey); + if (iter != pool.end()) { + auto& impls = iter->second; + for (auto& impl : impls) { + auto i = dynamic_cast*>(impl.get()); + if (i && i->UseMe(attr)) { + auto more = i->GetFunc(); + VLOG(10) << "Test More Kernel "; + test(more, args...); + } + } + } + // test result from Get function + VLOG(10) << "Test Get function "; + auto tgt = jit::Get(attr); + test(tgt, args...); } template @@ -113,79 +270,11 @@ void TestXYZNKernel() { ExpectEQ(xinp_data, zref_data, d); ExpectEQ(yinp_data, zref_data, d); - // test jitcode - auto jitcode = jit::GetJitCode, PlaceType>(d); - if (jitcode) { - VLOG(10) << "Test Jitcode Kernel, size: " << d; - TestXYZNFunc>(jitcode, x, y, zref); - } - - // test all impls in more - jit::KernelKey kkey(KT, PlaceType()); - auto& pool = jit::KernelPool().Instance().AllKernels(); - auto iter = pool.find(kkey); - if (iter != pool.end()) { - auto& impls = iter->second; - for (auto& impl : impls) { - auto i = dynamic_cast>*>( - impl.get()); - if (i && i->UseMe(d)) { - auto more = i->GetFunc(); - VLOG(10) << "Test More Kernel, size: " << d; - TestXYZNFunc>(more, x, y, zref); - } - } - } - // Test result from Get function - VLOG(10) << "Test Get function, size: " << d; - auto tgt = jit::Get, PlaceType>(d); - TestXYZNFunc>(tgt, x, y, zref); + TestAllImpls, PlaceType, std::vector, + std::vector, std::vector>(d, x, y, zref); } } -TEST(JITKernel, vmul) { - namespace jit = paddle::operators::jit; - TestXYZNKernel(); - TestXYZNKernel(); -} - -TEST(JITKernel, vadd) { - namespace jit = paddle::operators::jit; - TestXYZNKernel(); - TestXYZNKernel(); -} - -TEST(JITKernel, vaddrelu) { - namespace jit = paddle::operators::jit; - TestXYZNKernel(); - TestXYZNKernel(); -} - -TEST(JITKernel, vsub) { - namespace jit = paddle::operators::jit; - TestXYZNKernel(); - TestXYZNKernel(); -} - -template -void TestAXYNFunc(const typename KernelTuples::func_type tgt, const T a, - const std::vector& x, const std::vector& yref) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(yref.size(), x.size()); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - const int d = yref.size(); - std::vector ytgt(d); - T* ytgt_data = ytgt.data(); - // test normal - tgt(&a, x_data, ytgt_data, d); - ExpectEQ(ytgt_data, yref_data, d); - // test inplace x - std::copy(x.begin(), x.end(), ytgt.begin()); - tgt(&a, ytgt_data, ytgt_data, d); - ExpectEQ(ytgt_data, yref_data, d); -} - template void TestAXYNKernel() { namespace jit = paddle::operators::jit; @@ -208,67 +297,11 @@ void TestAXYNKernel() { ref(&a, xinp_data, xinp_data, d); ExpectEQ(xinp_data, yref_data, d); - // test jitcode - auto jitcode = jit::GetJitCode, PlaceType>(d); - if (jitcode) { - VLOG(10) << "Test Jitcode Kernel, size: " << d; - TestAXYNFunc>(jitcode, a, x, yref); - } - - // test all impls in more - jit::KernelKey kkey(KT, PlaceType()); - auto& pool = jit::KernelPool().Instance().AllKernels(); - auto iter = pool.find(kkey); - if (iter != pool.end()) { - auto& impls = iter->second; - for (auto& impl : impls) { - auto i = dynamic_cast>*>( - impl.get()); - if (i && i->UseMe(d)) { - auto more = i->GetFunc(); - VLOG(10) << "Test More Kernel, size: " << d; - TestAXYNFunc>(more, a, x, yref); - } - } - } - // Test result from Get function - VLOG(10) << "Test Get function, size: " << d; - auto tgt = jit::Get, PlaceType>(d); - TestAXYNFunc>(tgt, a, x, yref); + TestAllImpls, PlaceType, T, std::vector, + std::vector>(d, a, x, yref); } } -TEST(JITKernel, vscal) { - namespace jit = paddle::operators::jit; - TestAXYNKernel(); - TestAXYNKernel(); -} - -TEST(JITKernel, vaddbias) { - namespace jit = paddle::operators::jit; - TestAXYNKernel(); - TestAXYNKernel(); -} - -template -void TestXYNFunc(const typename KernelTuples::func_type tgt, - const std::vector& x, const std::vector& yref) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(yref.size(), x.size()); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - const int d = yref.size(); - std::vector ytgt(d); - T* ytgt_data = ytgt.data(); - // test normal - tgt(x_data, ytgt_data, d); - ExpectEQ(ytgt_data, yref_data, d); - // test inplace x - std::copy(x.begin(), x.end(), ytgt.begin()); - tgt(ytgt_data, ytgt_data, d); - ExpectEQ(ytgt_data, yref_data, d); -} - template void TestXYNKernel() { namespace jit = paddle::operators::jit; @@ -290,108 +323,11 @@ void TestXYNKernel() { ref(xinp_data, xinp_data, d); ExpectEQ(xinp_data, yref_data, d); - // test jitcode - auto jitcode = jit::GetJitCode, PlaceType>(d); - if (jitcode) { - VLOG(10) << "Test Jitcode Kernel, size: " << d; - TestXYNFunc>(jitcode, x, yref); - } - - // test all impls in more - jit::KernelKey kkey(KT, PlaceType()); - auto& pool = jit::KernelPool().Instance().AllKernels(); - auto iter = pool.find(kkey); - if (iter != pool.end()) { - auto& impls = iter->second; - for (auto& impl : impls) { - auto i = - dynamic_cast>*>(impl.get()); - if (i && i->UseMe(d)) { - auto more = i->GetFunc(); - VLOG(10) << "Test More Kernel, size: " << d; - TestXYNFunc>(more, x, yref); - } - } - } - // Test result from Get function - VLOG(10) << "Test Get function, size: " << d; - auto tgt = jit::Get, PlaceType>(d); - TestXYNFunc>(tgt, x, yref); + TestAllImpls, PlaceType, std::vector, + std::vector>(d, x, yref); } } -TEST(JITKernel, vrelu) { - namespace jit = paddle::operators::jit; - TestXYNKernel(); - TestXYNKernel(); -} - -TEST(JITKernel, videntity) { - namespace jit = paddle::operators::jit; - TestXYNKernel(); - TestXYNKernel(); -} - -TEST(JITKernel, vexp) { - namespace jit = paddle::operators::jit; - TestXYNKernel(); - TestXYNKernel(); -} - -TEST(JITKernel, vsigmoid) { - namespace jit = paddle::operators::jit; - TestXYNKernel(); - TestXYNKernel(); -} - -TEST(JITKernel, vtanh) { - namespace jit = paddle::operators::jit; - TestXYNKernel(); - TestXYNKernel(); -} - -template -void TestLSTMFunc(const typename KernelTuples::func_type tgt, - const std::vector& xsrc, const std::vector& wp, - const std::vector& ct_1, const std::vector& ct_ref, - const std::vector& ht_ref, - const paddle::operators::jit::lstm_attr_t& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(ct_ref.size(), ht_ref.size()); - EXPECT_EQ(ct_1.size(), ht_ref.size()); - EXPECT_EQ(xsrc.size(), 4 * ht_ref.size()); - EXPECT_EQ(wp.size(), 3 * ht_ref.size()); - - // x could be changed after compute, so copy to save src - int d = ht_ref.size(); - std::vector x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size()); - std::vector checked(2 * d); - std::copy(xsrc.begin(), xsrc.end(), x.begin()); - - const T* ct_1_data = ct_1.data(); - const T* wp_data = wp.data(); - const T* ct_ref_data = ct_ref.data(); - const T* ht_ref_data = ht_ref.data(); - T* x_data = x.data(); - T* ct_data = ct.data(); - T* ht_data = ht.data(); - T* checked_data = checked.data(); - - paddle::operators::jit::lstm_t step; - step.gates = x_data; - step.ct_1 = ct_1_data; - step.ct = ct_data; - step.ht = ht_data; - if (attr.use_peephole) { - step.wp = wp_data; - step.checked = checked_data; - } - - tgt(&step, &attr); - ExpectEQ(ct_data, ct_ref_data, d); - ExpectEQ(ht_data, ht_ref_data, d); -} - template void TestLSTMKernel() { namespace jit = paddle::operators::jit; @@ -435,37 +371,10 @@ void TestLSTMKernel() { } ref(&step, &attr); - // test jitcode - auto jitcode = - jit::GetJitCode, PlaceType>(attr); - if (jitcode) { - VLOG(10) << "Test Jitcode Kernel " << info; - TestLSTMFunc>(jitcode, xsrc, wp, ct_1, - ct_ref, ht_ref, attr); - } - - // test all impls in more - jit::KernelKey kkey(KT, PlaceType()); - auto& pool = jit::KernelPool().Instance().AllKernels(); - auto iter = pool.find(kkey); - if (iter != pool.end()) { - auto& impls = iter->second; - for (auto& impl : impls) { - auto i = - dynamic_cast>*>( - impl.get()); - if (i && i->UseMe(attr)) { - auto more = i->GetFunc(); - VLOG(10) << "Test More Kernel " << info; - TestLSTMFunc>(more, xsrc, wp, ct_1, - ct_ref, ht_ref, attr); - } - } - } - // Test result from Get function - auto tgt = jit::Get, PlaceType>(attr); - TestLSTMFunc>(tgt, xsrc, wp, ct_1, ct_ref, - ht_ref, attr); + TestAllImpls, PlaceType, std::vector, + std::vector, std::vector, std::vector, + std::vector>(attr, xsrc, wp, ct_1, ct_ref, ht_ref, + attr); } } } @@ -473,43 +382,6 @@ void TestLSTMKernel() { } } -TEST(JITKernel, lstmctht) { - namespace jit = paddle::operators::jit; - TestLSTMKernel(); - TestLSTMKernel(); -} - -TEST(JITKernel, lstmc1h1) { - namespace jit = paddle::operators::jit; - TestLSTMKernel(); - TestLSTMKernel(); -} - -template -void TestGRUFunc(const typename KernelTuples::func_type tgt, - const std::vector& xsrc, const std::vector& ht_1, - const std::vector& ht_ref, - const paddle::operators::jit::gru_attr_t& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(ht_1.size(), ht_ref.size()); - EXPECT_EQ(xsrc.size(), 3 * ht_ref.size()); - - // x could be changed after compute, so copy to save src - int d = ht_ref.size(); - std::vector x(xsrc.size()), ht(ht_ref.size()); - std::copy(xsrc.begin(), xsrc.end(), x.begin()); - const T* ht_1_data = ht_1.data(); - const T* ht_ref_data = ht_ref.data(); - T* x_data = x.data(); - T* ht_data = ht.data(); - paddle::operators::jit::gru_t step; - step.gates = x_data; - step.ht_1 = ht_1_data; - step.ht = ht_data; - tgt(&step, &attr); - ExpectEQ(ht_data, ht_ref_data, d); -} - template void TestGRUKernel() { namespace jit = paddle::operators::jit; @@ -538,37 +410,97 @@ void TestGRUKernel() { step.ht = ht_ref_data; ref(&step, &attr); - // test jitcode - auto jitcode = jit::GetJitCode, PlaceType>(attr); - if (jitcode) { - VLOG(10) << "Test Jitcode Kernel " << info; - TestGRUFunc>(jitcode, xsrc, ht_1, ht_ref, attr); - } - - // test all impls in more - jit::KernelKey kkey(KT, PlaceType()); - auto& pool = jit::KernelPool().Instance().AllKernels(); - auto iter = pool.find(kkey); - if (iter != pool.end()) { - auto& impls = iter->second; - for (auto& impl : impls) { - auto i = dynamic_cast>*>( - impl.get()); - if (i && i->UseMe(attr)) { - auto more = i->GetFunc(); - VLOG(10) << "Test More Kernel " << info; - TestGRUFunc>(more, xsrc, ht_1, ht_ref, attr); - } - } - } - // Test result from Get function - auto tgt = jit::Get, PlaceType>(attr); - TestGRUFunc>(tgt, xsrc, ht_1, ht_ref, attr); + TestAllImpls, PlaceType, std::vector, + std::vector, std::vector>(attr, xsrc, ht_1, ht_ref, + attr); } } } } +// XYZNTuple +TEST(JITKernel, vmul) { + namespace jit = paddle::operators::jit; + TestXYZNKernel(); + TestXYZNKernel(); +} + +TEST(JITKernel, vadd) { + namespace jit = paddle::operators::jit; + TestXYZNKernel(); + TestXYZNKernel(); +} + +TEST(JITKernel, vaddrelu) { + namespace jit = paddle::operators::jit; + TestXYZNKernel(); + TestXYZNKernel(); +} + +TEST(JITKernel, vsub) { + namespace jit = paddle::operators::jit; + TestXYZNKernel(); + TestXYZNKernel(); +} + +// AXYNTuples +TEST(JITKernel, vscal) { + namespace jit = paddle::operators::jit; + TestAXYNKernel(); + TestAXYNKernel(); +} + +TEST(JITKernel, vaddbias) { + namespace jit = paddle::operators::jit; + TestAXYNKernel(); + TestAXYNKernel(); +} + +// XYNTuples +TEST(JITKernel, vrelu) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +TEST(JITKernel, videntity) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +TEST(JITKernel, vexp) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +TEST(JITKernel, vsigmoid) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +TEST(JITKernel, vtanh) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +// LSTM +TEST(JITKernel, lstmctht) { + namespace jit = paddle::operators::jit; + TestLSTMKernel(); + TestLSTMKernel(); +} + +TEST(JITKernel, lstmc1h1) { + namespace jit = paddle::operators::jit; + TestLSTMKernel(); + TestLSTMKernel(); +} + +// GRU TEST(JITKernel, gruh1) { namespace jit = paddle::operators::jit; TestGRUKernel();