From 28eb7d840c9baaf49303cada7fcc71f557abb78a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 11 Dec 2018 09:29:15 +0000 Subject: [PATCH] test all impls and all inplace cases --- paddle/fluid/operators/jit/helper.h | 53 ++++++++----- paddle/fluid/operators/jit/test.cc | 117 +++++++++++++++++++++------- 2 files changed, 121 insertions(+), 49 deletions(-) diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index c8da960a1e1..09a6bc3d9d7 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -28,33 +28,16 @@ namespace paddle { namespace operators { namespace jit { -// Refer code do not related with attr, and always on CPUPlace -template -inline Func GetRefer() { - auto& ref_pool = ReferKernelPool().Instance().AllKernels(); - KernelKey kkey(KT, platform::CPUPlace()); - auto ref_iter = ref_pool.find(kkey); - PADDLE_ENFORCE(ref_iter != ref_pool.end(), - "Every Kernel should have reference function."); - auto& ref_impls = ref_iter->second; - for (auto& impl : ref_impls) { - auto i = dynamic_cast*>(impl.get()); - if (i) { - return i->GetFunc(); - } - } - return nullptr; -} - template -const Func Get(Attr attr) { + typename PlaceType> +inline const Func GetJitCode(Attr attr) { size_t key = JitCodeKey(attr); auto& codes = JitCodePool().Instance(); if (codes.Has(key)) { return codes.AllKernels().at(key)->template getCode(); } + // creator is not related with attr, so can use KernelKey as key KernelKey kkey(KT, PlaceType()); if (std::is_same::value) { // pool: (KernelKey(type, place), vector) @@ -73,8 +56,38 @@ const Func Get(Attr attr) { } } } + return nullptr; +} + +// Refer code do not related with attr, which is just for cast +// Refer is always on CPUPlace +template +inline Func GetRefer() { + auto& ref_pool = ReferKernelPool().Instance().AllKernels(); + KernelKey kkey(KT, platform::CPUPlace()); + auto ref_iter = ref_pool.find(kkey); + PADDLE_ENFORCE(ref_iter != ref_pool.end(), + "Every Kernel should have reference function."); + auto& ref_impls = ref_iter->second; + for (auto& impl : ref_impls) { + auto i = dynamic_cast*>(impl.get()); + if (i) { + return i->GetFunc(); + } + } + return nullptr; +} + +template +const Func Get(Attr attr) { + auto jitfunc = GetJitCode(attr); + if (jitfunc) { + return jitfunc; + } // pool: (KernelKey(type, place), vector) + KernelKey kkey(KT, PlaceType()); auto& pool = KernelPool().Instance().AllKernels(); auto iter = pool.find(kkey); if (iter != pool.end()) { diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index e531ba1a2c4..e523089101f 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -55,46 +55,105 @@ void ExpectEQ(const T* target, const T* refer, int n) { } } +std::vector TestSizes() { + std::vector s; + for (int i = 1; i < 30; ++i) { + s.push_back(i); + } + // test some large size + s.push_back(100); + s.push_back(1000); + return s; +} + +template +void TestTartgetFunc(const Func tgt, const std::vector& x, + const std::vector& y, const std::vector& zref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(zref.size(), x.size()); + EXPECT_EQ(zref.size(), y.size()); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* zref_data = zref.data(); + const int d = zref.size(); + + std::vector ztgt(d); + T* ztgt_data = ztgt.data(); + // test normal + tgt(x_data, y_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ztgt.begin()); + tgt(ztgt_data, y_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + // test inplace y + std::copy(y.begin(), y.end(), ztgt.begin()); + tgt(x_data, ztgt_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); +} + TEST(JitKernel, vmul) { using T = float; using PlaceType = paddle::platform::CPUPlace; - namespace jit = paddle::operators::jit; - // TODO(TJ): test more vector size - for (int d = 1; d < 30; ++d) { - auto ref = jit::GetRefer::func_type, + const auto KT = jit::vmul; + for (int d : TestSizes()) { + auto ref = jit::GetRefer::func_type, jit::VMulTuples::attr_type>(); - auto tgt = jit::Get::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); EXPECT_TRUE(ref != nullptr); - EXPECT_TRUE(tgt != nullptr); - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); + std::vector x(d), y(d), zref(d); RandomVec(d, x.data()); RandomVec(d, y.data()); - const float* x_data = x.data(); - const float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - tgt(x_data, y_data, ztgt_data, d); + std::vector xinp(d), yinp(d); // inplace test + std::copy(x.begin(), x.end(), xinp.begin()); + std::copy(y.begin(), y.end(), yinp.begin()); + + const T* x_data = x.data(); + const T* y_data = y.data(); + T* zref_data = zref.data(); + T* xinp_data = xinp.data(); + T* yinp_data = yinp.data(); + + // test refer code inplace ref(x_data, y_data, zref_data, d); - ExpectEQ(ztgt_data, zref_data, d); - - // test inplace x - std::copy(x.begin(), x.end(), zref.begin()); - std::copy(x.begin(), x.end(), ztgt.begin()); - tgt(ztgt_data, y_data, ztgt_data, d); - ref(zref_data, y_data, zref_data, d); - ExpectEQ(ztgt_data, zref_data, d); - - // test inplace y - std::copy(y.begin(), y.end(), zref.begin()); - std::copy(y.begin(), y.end(), ztgt.begin()); - tgt(x_data, ztgt_data, ztgt_data, d); - ref(x_data, zref_data, zref_data, d); - ExpectEQ(ztgt_data, zref_data, d); + ref(x_data, yinp_data, yinp_data, d); + ref(xinp_data, y_data, xinp_data, d); + ExpectEQ(xinp_data, zref_data, d); + ExpectEQ(yinp_data, zref_data, d); + + // test jitcode + auto jitcode = jit::GetJitCode::func_type, + jit::VMulTuples::attr_type, PlaceType>(d); + if (jitcode) { + VLOG(10) << "Test jitcode, size: " << d; + TestTartgetFunc::func_type>(jitcode, x, y, zref); + } + + // test all impls in more + jit::KernelKey kkey(KT, PlaceType()); + auto& pool = jit::KernelPool().Instance().AllKernels(); + auto iter = pool.find(kkey); + if (iter != pool.end()) { + auto& impls = iter->second; + for (auto& impl : impls) { + auto i = + dynamic_cast::func_type, + jit::VMulTuples::attr_type>*>( + impl.get()); + if (i && i->UseMe(d)) { + auto more = i->GetFunc(); + VLOG(10) << "Test More Kernel, size: " << d; + TestTartgetFunc::func_type>(more, x, y, zref); + } + } + } + // Test result from Get function + VLOG(10) << "Test Get function, size: " << d; + auto tgt = jit::Get::func_type, + jit::VMulTuples::attr_type, PlaceType>(d); + TestTartgetFunc::func_type>(tgt, x, y, zref); } } -- GitLab