From 45bdd84dac51b6f3fb4315b144f958e6e15c8389 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Sun, 10 Mar 2019 16:26:59 +0000
Subject: [PATCH] enhance the jitkernel helper and add unit tests

test=develop
---
 paddle/fluid/operators/jit/benchmark.cc       |  28 +--
 paddle/fluid/operators/jit/gen/act.cc         |  14 +-
 paddle/fluid/operators/jit/gen/blas.cc        |   4 +-
 paddle/fluid/operators/jit/gen/embseqpool.cc  |   2 +-
 paddle/fluid/operators/jit/gen/gru.cc         |   2 +-
 paddle/fluid/operators/jit/gen/hopv.cc        |   2 +-
 paddle/fluid/operators/jit/gen/jitcode.h      |   2 +-
 paddle/fluid/operators/jit/gen/lstm.cc        |   2 +-
 paddle/fluid/operators/jit/gen/matmul.cc      |   2 +-
 paddle/fluid/operators/jit/gen/seqpool.cc     |   2 +-
 paddle/fluid/operators/jit/gen/sgd.cc         |   2 +-
 paddle/fluid/operators/jit/gen/vbroadcast.cc  |   2 +-
 paddle/fluid/operators/jit/gen_base.cc        |   2 +-
 paddle/fluid/operators/jit/gen_base.h         |   7 +-
 paddle/fluid/operators/jit/helper.h           | 116 +++++++---
 paddle/fluid/operators/jit/kernel_base.h      |   7 +-
 paddle/fluid/operators/jit/kernel_key.cc      |   3 +
 .../jit/more/intrinsic/crf_decoding.cc        |   2 +-
 .../jit/more/intrinsic/crf_decoding.h         |   3 +-
 .../jit/more/intrinsic/layer_norm.cc          |   2 +-
 .../operators/jit/more/intrinsic/layer_norm.h |   3 +-
 paddle/fluid/operators/jit/more/mix/mix.cc    |  16 +-
 paddle/fluid/operators/jit/more/mix/mix.h     |  12 +-
 paddle/fluid/operators/jit/more/mkl/mkl.cc    |  47 ++--
 paddle/fluid/operators/jit/more/mkl/mkl.h     |  14 +-
 paddle/fluid/operators/jit/registry.h         |   4 +-
 paddle/fluid/operators/jit/test.cc            | 208 ++++++++++++++----
 27 files changed, 328 insertions(+), 182 deletions(-)

diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 773cf38eb99..fbb04a166ef 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -111,33 +111,11 @@ template <typename KernelTuple, typename PlaceType, typename... Args>
 void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
   BenchFunc<KernelTuple, Args...> benchmark;
   std::vector<std::pair<std::string, double>> infos;
-  // test refer
-  auto refer = jit::GetRefer<KernelTuple>();
-  if (!refer) {
-    LOG(FATAL) << "Refer can not be empty!";
+  auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  for (auto f : funcs) {
+    infos.push_back(std::make_pair(f.first, benchmark(f.second, args...)));
   }
-  infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
-  // test jitcode
-  auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
-  if (jitcode) {
-    infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
-  }
-  // test all impls in more
-  jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
-  auto& pool = jit::KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        auto more = i->GetFunc();
-        infos.push_back(
-            std::make_pair(i->ImplType(), benchmark(more, args...)));
-      }
-    }
-  }
 
   // Test result from Get function
   auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
   if (!tgt) {
diff --git a/paddle/fluid/operators/jit/gen/act.cc b/paddle/fluid/operators/jit/gen/act.cc
index e7a73758790..5cac219f95f 100644
--- a/paddle/fluid/operators/jit/gen/act.cc
+++ b/paddle/fluid/operators/jit/gen/act.cc
@@ -81,7 +81,7 @@ void VActJitCode::genCode() {
 #define DECLARE_ACT_CREATOR(name)                                            \
   class name##Creator : public JitCodeCreator<int> {                         \
    public:                                                                   \
-    bool UseMe(const int& attr) const override;                              \
+    bool CanBeUsed(const int& attr) const override;                          \
     size_t CodeSize(const int& d) const override;                            \
     std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
       return make_unique<name##JitCode>(attr, CodeSize(attr));               \
@@ -96,27 +96,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
 DECLARE_ACT_CREATOR(VTanh);
 
 // TODO(TJ): tuning use me
-bool VReluCreator::UseMe(const int& d) const {
+bool VReluCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VSquareCreator::UseMe(const int& d) const {
+bool VSquareCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VIdentityCreator::UseMe(const int& d) const {
+bool VIdentityCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VExpCreator::UseMe(const int& d) const {
+bool VExpCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d < 32;
 }
 
-bool VSigmoidCreator::UseMe(const int& d) const {
+bool VSigmoidCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VTanhCreator::UseMe(const int& d) const {
+bool VTanhCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc
index 5da24c359ed..e764a7983d3 100644
--- a/paddle/fluid/operators/jit/gen/blas.cc
+++ b/paddle/fluid/operators/jit/gen/blas.cc
@@ -142,7 +142,7 @@ void NCHW16CMulNCJitCode::genCode() {
 
 class NCHW16CMulNCCreator : public JitCodeCreator<int> {
  public:
-  bool UseMe(const int& attr) const override {
+  bool CanBeUsed(const int& attr) const override {
     return platform::MayIUse(platform::avx512f);
   }
   size_t CodeSize(const int& d) const override { return 256 * 1024; }
@@ -154,7 +154,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
 #define DECLARE_BLAS_CREATOR(name)                                   \
   class name##Creator : public JitCodeCreator<int> {                 \
    public:                                                           \
-    bool UseMe(const int& attr) const override {                     \
+    bool CanBeUsed(const int& attr) const override {                 \
       return platform::MayIUse(platform::avx) && attr <= 1024;      \
     }                                                                \
     size_t CodeSize(const int& d) const override {                   \
diff --git a/paddle/fluid/operators/jit/gen/embseqpool.cc b/paddle/fluid/operators/jit/gen/embseqpool.cc
index 23837a3fb98..6e8ecc07e74 100644
--- a/paddle/fluid/operators/jit/gen/embseqpool.cc
+++ b/paddle/fluid/operators/jit/gen/embseqpool.cc
@@ -121,7 +121,7 @@ void EmbSeqPoolJitCode::genCode() {
 
 class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
  public:
-  bool UseMe(const emb_seq_pool_attr_t& attr) const override {
+  bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override {
     return platform::MayIUse(platform::avx) &&
            attr.table_width % YMM_FLOAT_BLOCK == 0;
   }
diff --git a/paddle/fluid/operators/jit/gen/gru.cc b/paddle/fluid/operators/jit/gen/gru.cc
index 13f7a14111a..4bc9247f6f0 100644
--- a/paddle/fluid/operators/jit/gen/gru.cc
+++ b/paddle/fluid/operators/jit/gen/gru.cc
@@ -86,7 +86,7 @@ void GRUJitCode::genCode() {
   class name##Creator : public JitCodeCreator<gru_attr_t> {        \
    public:                                                         \
     /* TODO(TJ): enable more */                                     \
-    bool UseMe(const gru_attr_t& attr) const override {             \
+    bool CanBeUsed(const gru_attr_t& attr) const override {         \
       return platform::MayIUse(platform::avx) && attr.d % 8 == 0;   \
     }                                                               \
     size_t CodeSize(const gru_attr_t& attr) const override {        \
diff --git a/paddle/fluid/operators/jit/gen/hopv.cc b/paddle/fluid/operators/jit/gen/hopv.cc
index e7884017198..3383f17df8f 100644
--- a/paddle/fluid/operators/jit/gen/hopv.cc
+++ b/paddle/fluid/operators/jit/gen/hopv.cc
@@ -76,7 +76,7 @@ void HOPVJitCode::genCode() {
 #define DECLARE_HOP_CREATOR(name)                         \
   class name##Creator : public JitCodeCreator<int> {      \
    public:                                                \
-    bool UseMe(const int& attr) const override {          \
+    bool CanBeUsed(const int& attr) const override {      \
       return platform::MayIUse(platform::avx);            \
     }                                                     \
     size_t CodeSize(const int& d) const override {        \
diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h
index 39847d1b65f..228db7cc721 100644
--- a/paddle/fluid/operators/jit/gen/jitcode.h
+++ b/paddle/fluid/operators/jit/gen/jitcode.h
@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
   virtual void genCode() = 0;
 
   size_t getSize() const override { return CodeGenerator::getSize(); }
-  const unsigned char* getCodeInternal() override {
+  const unsigned char* getCodeInternal() const override {
     const Xbyak::uint8* code = CodeGenerator::getCode();
     return code;
   }
diff --git a/paddle/fluid/operators/jit/gen/lstm.cc b/paddle/fluid/operators/jit/gen/lstm.cc
index 08bafb5a818..5e7789aede1 100644
--- a/paddle/fluid/operators/jit/gen/lstm.cc
+++ b/paddle/fluid/operators/jit/gen/lstm.cc
@@ -114,7 +114,7 @@ void LSTMJitCode::genCode() {
   class name##Creator : public JitCodeCreator<lstm_attr_t> {       \
    public:                                                         \
     /* TODO(TJ): enable more */                                     \
-    bool UseMe(const lstm_attr_t& attr) const override {            \
+    bool CanBeUsed(const lstm_attr_t& attr) const override {        \
       return platform::MayIUse(platform::avx) && attr.d % 8 == 0;   \
     }                                                               \
     size_t CodeSize(const lstm_attr_t& attr) const override {       \
diff --git a/paddle/fluid/operators/jit/gen/matmul.cc b/paddle/fluid/operators/jit/gen/matmul.cc
index ae3858eab20..ca50f26ce57 100644
--- a/paddle/fluid/operators/jit/gen/matmul.cc
+++ b/paddle/fluid/operators/jit/gen/matmul.cc
@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
 
 class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
  public:
-  bool UseMe(const matmul_attr_t& attr) const override {
+  bool CanBeUsed(const matmul_attr_t& attr) const override {
     return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
            attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
   }
diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc
index 530d24ee1fb..ceca104cc98 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.cc
+++ b/paddle/fluid/operators/jit/gen/seqpool.cc
@@ -57,7 +57,7 @@ void SeqPoolJitCode::genCode() {
 
 class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
  public:
-  bool UseMe(const seq_pool_attr_t& attr) const override {
+  bool CanBeUsed(const seq_pool_attr_t& attr) const override {
     return platform::MayIUse(platform::avx);
   }
   size_t CodeSize(const seq_pool_attr_t& attr) const override {
diff --git a/paddle/fluid/operators/jit/gen/sgd.cc b/paddle/fluid/operators/jit/gen/sgd.cc
index a745a27f954..a40da9b9932 100644
--- a/paddle/fluid/operators/jit/gen/sgd.cc
+++ b/paddle/fluid/operators/jit/gen/sgd.cc
@@ -104,7 +104,7 @@ void SgdJitCode::genCode() {
 
 class SgdCreator : public JitCodeCreator<sgd_attr_t> {
  public:
-  bool UseMe(const sgd_attr_t& attr) const override {
+  bool CanBeUsed(const sgd_attr_t& attr) const override {
     return platform::MayIUse(platform::avx) &&
            attr.grad_width % YMM_FLOAT_BLOCK == 0;
   }
diff --git a/paddle/fluid/operators/jit/gen/vbroadcast.cc b/paddle/fluid/operators/jit/gen/vbroadcast.cc
index 3f9fbdbd821..66a8d75fd4d 100644
--- a/paddle/fluid/operators/jit/gen/vbroadcast.cc
+++ b/paddle/fluid/operators/jit/gen/vbroadcast.cc
@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
 
 class VBroadcastCreator : public JitCodeCreator<int64_t> {
  public:
-  bool UseMe(const int64_t& w) const override {
+  bool CanBeUsed(const int64_t& w) const override {
     return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
   }
   size_t CodeSize(const int64_t& w) const override {
diff --git a/paddle/fluid/operators/jit/gen_base.cc b/paddle/fluid/operators/jit/gen_base.cc
index f3603875ad7..4c49eff49e3 100644
--- a/paddle/fluid/operators/jit/gen_base.cc
+++ b/paddle/fluid/operators/jit/gen_base.cc
@@ -31,7 +31,7 @@ namespace paddle {
 namespace operators {
 namespace jit {
 
-// refer do not need useme, it would be the last one.
+// refer does not need CanBeUsed; it is always the last candidate.
 void GenBase::dumpCode(const unsigned char* code) const {
   if (code) {
     static int counter = 0;
diff --git a/paddle/fluid/operators/jit/gen_base.h b/paddle/fluid/operators/jit/gen_base.h
index a7c7a35a7ea..033c603c07c 100644
--- a/paddle/fluid/operators/jit/gen_base.h
+++ b/paddle/fluid/operators/jit/gen_base.h
@@ -31,9 +31,10 @@ class GenBase : public Kernel {
   virtual ~GenBase() = default;
   virtual std::string name() const = 0;
   virtual size_t getSize() const = 0;
-  virtual const unsigned char* getCodeInternal() = 0;
+  virtual const unsigned char* getCodeInternal() const = 0;
+  const char* ImplType() const override { return "JitCode"; }
   template <typename Func>
-  Func getCode() {
+  Func getCode() const {
     const unsigned char* code = this->getCodeInternal();
     if (FLAGS_dump_jitcode) {
       this->dumpCode(code);
@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
   virtual ~JitCodeCreator() = default;
 
   // condition when this jit code can be used.
-  virtual bool UseMe(const Attr& attr) const = 0;
+  virtual bool CanBeUsed(const Attr& attr) const = 0;
   // estimate this code size
   virtual size_t CodeSize(const Attr& attr) const = 0;
diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h
index 85f4072dd30..d98eada81c0 100644
--- a/paddle/fluid/operators/jit/helper.h
+++ b/paddle/fluid/operators/jit/helper.h
@@ -14,9 +14,6 @@
 
 #pragma once
 
-extern "C" {
-#include <xxhash.h>
-}
 #include <iostream>
 #include <string>
 #include <unordered_map>
@@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
 inline typename std::enable_if<
     std::is_same<typename KernelTuple::data_type, float>::value &&
         std::is_same<PlaceType, platform::CPUPlace>::value,
-    typename KernelTuple::func_type>::type
+    const Kernel*>::type
 GetJitCode(const typename KernelTuple::attr_type& attr) {
-  using Func = typename KernelTuple::func_type;
   using Attr = typename KernelTuple::attr_type;
   size_t key = JitCodeKey<Attr>(attr);
-  auto& codes = JitCodePool<KernelTuple::kernel_type>().Instance();
+  auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
   if (codes.Has(key)) {
-    return codes.AllKernels().at(key)->template getCode<Func>();
+    return codes.AllKernels().at(key).get();
   }
 
   // creator is not related with attr, so can use KernelKey as key
   KernelKey kkey(KernelTuple::kernel_type, PlaceType());
   // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
-  auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
+  auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
   auto iter = creator_map.find(kkey);
   if (iter != creator_map.end()) {
     auto& creators = iter->second;
     for (auto& cur : creators) {
       auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
-      if (i && i->UseMe(attr)) {
+      if (i && i->CanBeUsed(attr)) {
         auto p = i->CreateJitCode(attr);
         if (p) {
-          auto f = p->template getCode<Func>();
+          auto res = p.get();
           codes.Insert(key, std::move(p));
-          return f;
+          return res;
         }
       }
     }
@@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
 inline typename std::enable_if<
     !std::is_same<typename KernelTuple::data_type, float>::value ||
         !std::is_same<PlaceType, platform::CPUPlace>::value,
-    typename KernelTuple::func_type>::type
+    const Kernel*>::type
 GetJitCode(const typename KernelTuple::attr_type& attr) {
   return nullptr;
 }
@@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
 // Refer code do not related with attr, which is just for cast
 // Refer is always on CPUPlace
 template <typename KernelTuple>
-inline typename KernelTuple::func_type GetRefer() {
-  auto& ref_pool = ReferKernelPool().Instance().AllKernels();
+inline const Kernel* GetReferKernel() {
+  auto& ref_pool = ReferKernelPool::Instance().AllKernels();
   KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
   auto ref_iter = ref_pool.find(kkey);
   PADDLE_ENFORCE(ref_iter != ref_pool.end(),
@@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
   for (auto& impl : ref_impls) {
     auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
     if (i) {
-      return i->GetFunc();
+      return i;
     }
   }
   return nullptr;
 }
 
-template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
-typename KernelTuple::func_type Get(
+template <typename KernelTuple>
+inline typename KernelTuple::func_type GetReferFunc() {
+  auto ker = GetReferKernel<KernelTuple>();
+  auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
+  PADDLE_ENFORCE(p, "The Refer kernel should exist");
+  return p->GetFunc();
+}
+
+// Return all Kernels that can be used
+template <typename KernelTuple, typename PlaceType>
+std::vector<const Kernel*> GetAllCandidateKernels(
     const typename KernelTuple::attr_type& attr) {
-  auto jitfunc = GetJitCode<KernelTuple, PlaceType>(attr);
-  if (jitfunc) {
-    return jitfunc;
+  // the search order should be jitcode > more > refer
+  std::vector<const Kernel*> res;
+  auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
+  if (jitker) {
+    res.emplace_back(jitker);
   }
-  // pool: (KernelKey(type, place), vector<KernelPtr>)
+
+  // more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
   KernelKey kkey(KernelTuple::kernel_type, PlaceType());
-  auto& pool = KernelPool().Instance().AllKernels();
+  auto& pool = KernelPool::Instance().AllKernels();
   auto iter = pool.find(kkey);
   if (iter != pool.end()) {
     auto& impls = iter->second;
     for (auto& impl : impls) {
       auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        return i->GetFunc();
+      if (i && i->CanBeUsed(attr)) {
+        res.emplace_back(i);
       }
     }
   }
 
   // The last implementation should be reference function on CPUPlace.
-  return GetRefer<KernelTuple>();
+  auto ref = GetReferKernel<KernelTuple>();
+  PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
+  res.emplace_back(ref);
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType>
+std::vector<std::pair<std::string, typename KernelTuple::func_type>>
+GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
+  using Func = typename KernelTuple::func_type;
+  auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
+  std::vector<std::pair<std::string, Func>> res;
+  for (auto k : kers) {
+    std::string name = k->ImplType();
+    if (name == "JitCode") {
+      auto i = dynamic_cast<const GenBase*>(k);
+      PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
+      res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
+    } else {
+      auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
+      PADDLE_ENFORCE(i, "kernel cast can not fail.");
+      res.emplace_back(std::make_pair(name, i->GetFunc()));
+    }
+  }
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType>
+std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
+    const typename KernelTuple::attr_type& attr) {
+  auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  std::vector<typename KernelTuple::func_type> res;
+  for (auto& i : funcs) {
+    res.emplace_back(i.second);
+  }
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType>
+typename KernelTuple::func_type GetDefaultBestFunc(
+    const typename KernelTuple::attr_type& attr) {
+  auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
+  PADDLE_ENFORCE_GE(funcs.size(), 1UL);
+  // Here we could run a runtime benchmark of this attr and return the best.
+  // For now just take the first one, since candidates are searched in order
+  // and tuned offline.
+  return funcs[0];
+}
 
 template <typename KernelTuple, typename PlaceType>
@@ -134,17 +187,13 @@ class KernelFuncs {
   // the exposed interface to use
   typename KernelTuple::func_type At(
       const typename KernelTuple::attr_type& attr) {
-    // XXH64: 13.8 GB/s
-    // TODO(TJ): change me, maybe not all attr change need one key, should be
-    // attrkey
-    int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0);
+    // Maybe this is not good enough yet: not all kernels have a jitcode key
+    int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
     if (Has(key)) {
      return funcs_.at(key);
    }
-    // If do not have this attr in cache,
-    // then could run some runtime benchmark of this attr and save the best one.
-    // Here just get the offline benchmarked best one.
-    auto func = Get<KernelTuple, PlaceType>(attr);
+    // If this attr is not in the cache, fall back to the default best
+    auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
     Insert(key, func);
     return func;
   }
@@ -156,7 +205,6 @@ class KernelFuncs {
 
  protected:
   bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
-
   void Insert(int64_t key, typename KernelTuple::func_type func) {
     funcs_.emplace(key, func);
   }
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index e8dbcced4f1..bd34d7dfc72 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -302,6 +302,7 @@ class Kernel {
  public:
   Kernel() = default;
   virtual ~Kernel() = default;
+  virtual const char* ImplType() const = 0;
   DISABLE_COPY_AND_ASSIGN(Kernel);
 };
 
@@ -312,8 +313,8 @@ class KernelMore : public Kernel {
   using Func = typename KernelTuple::func_type;
   using Attr = typename KernelTuple::attr_type;
   virtual Func GetFunc() const { return func; }
-  virtual bool UseMe(const Attr& attr) const = 0;
-  virtual const char* ImplType() const = 0;
+  // Specifies whether this kernel can be used; if so, calling it must not fail.
+  virtual bool CanBeUsed(const Attr& attr) const = 0;
 
  protected:
   Func func{nullptr};
@@ -323,7 +324,7 @@ template <typename KernelTuple>
 class ReferKernel : public KernelMore<KernelTuple> {
  public:
   // Refer code can always be used
-  bool UseMe(const typename KernelTuple::attr_type& attr) const override {
+  bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
     return true;
   }
   const char* ImplType() const override { return "Refer"; }
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index 1c2fddcae79..6987c893de4 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -13,6 +13,7 @@
 * limitations under the License.
 */
 #include "paddle/fluid/operators/jit/kernel_key.h"
+#include <xxhash.h>
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -49,6 +50,8 @@ static inline int act_type_convert(KernelType type) {
 
 template <>
 size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
+  // XXH64: 13.8 GB/s
+
   size_t key = attr.d;
   int gate_key = act_type_convert(attr.act_gate) << 1;
   int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
index 16c91f8246d..1254d00189a 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
+++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
   }
 }
 
-bool CRFDecodingKernel::UseMe(const int& d) const {
+bool CRFDecodingKernel::CanBeUsed(const int& d) const {
 #ifdef __AVX512F__
   constexpr int block = ZMM_FLOAT_BLOCK;
 #else
diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
index f4187bd3ba2..49b1a1fea4b 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
+++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
@@ -29,7 +29,8 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
 class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
  public:
   CRFDecodingKernel() { this->func = CRFDecoding; }
-  bool UseMe(const typename CRFDecodingTuple<float>::attr_type&) const override;
+  bool CanBeUsed(
+      const typename CRFDecodingTuple<float>::attr_type&) const override;
   const char* ImplType() const override { return "Intrinsic"; }
 };
diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
index e9b6e401c68..a4e3246f104 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
+++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
   }
 }
 
-bool LayerNormKernel::UseMe(const int& d) const {
+bool LayerNormKernel::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
 }
diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
index dfa4c2f072f..7b9f676050d 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
+++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
@@ -30,7 +30,8 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
 class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
  public:
   LayerNormKernel() { this->func = LayerNorm; }
-  bool UseMe(const typename LayerNormTuple<float>::attr_type&) const override;
+  bool CanBeUsed(
+      const typename LayerNormTuple<float>::attr_type&) const override;
   const char* ImplType() const override { return "Intrinsic"; }
 };
diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc
index 9ee1032e95e..6e709a16d23 100644
--- a/paddle/fluid/operators/jit/more/mix/mix.cc
+++ b/paddle/fluid/operators/jit/more/mix/mix.cc
@@ -204,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
 }
 
 // TODO(TJ): tuning me
-bool VSigmoidKernel::UseMe(const int& d) const { return true; }
+bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; }
 
-bool VTanhKernel::UseMe(const int& d) const { return true; }
+bool VTanhKernel::CanBeUsed(const int& d) const { return true; }
 
-bool SoftmaxKernel::UseMe(const int& d) const { return true; }
+bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; }
 
-bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
+bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
 
-bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
+bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
 
-bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUH1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
-bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUHtPart1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
-bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
 }  // namespace mix
 }  // namespace more
diff --git a/paddle/fluid/operators/jit/more/mix/mix.h b/paddle/fluid/operators/jit/more/mix/mix.h
index 17eb96462f9..994d485909c 100644
--- a/paddle/fluid/operators/jit/more/mix/mix.h
+++ b/paddle/fluid/operators/jit/more/mix/mix.h
@@ -34,12 +34,12 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
 void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
 void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
 
-#define DECLARE_MORE_KERNEL(name)                                          \
-  class name##Kernel : public KernelMore<name##Tuple<T>> {                 \
-   public:                                                                 \
-    name##Kernel() { this->func = name; }                                  \
-    bool UseMe(const typename name##Tuple<T>::attr_type&) const override;  \
-    const char* ImplType() const override { return "Mixed"; }              \
+#define DECLARE_MORE_KERNEL(name)                                             \
+  class name##Kernel : public KernelMore<name##Tuple<T>> {                    \
+   public:                                                                    \
+    name##Kernel() { this->func = name; }                                     \
+    bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
+    const char* ImplType() const override { return "Mixed"; }                 \
   }
 
 // XYN
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc
index 084ea571cea..4f600b38144 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.cc
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc
@@ -130,105 +130,106 @@ void ASum(const double* x, double* res, int n) {
 
 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
 template <>
-bool VMulKernel<float>::UseMe(const int& d) const {
+bool VMulKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
-bool VAddKernel<float>::UseMe(const int& d) const {
+bool VAddKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d > 512;
 }
 
 template <>
-bool VScalKernel<float>::UseMe(const int& d) const {
+bool VScalKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
-bool VExpKernel<float>::UseMe(const int& d) const {
+bool VExpKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VSquareKernel<float>::UseMe(const int& d) const {
+bool VSquareKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VCopyKernel<float>::UseMe(const int& d) const {
+bool VCopyKernel<float>::CanBeUsed(const int& d) const {
   return d > 15;
 }
 
 template <>
-bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
+bool VBroadcastKernel<float>::CanBeUsed(const int64_t& d) const {
   return d > 127;
 }
 
 template <>
-bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
+bool VBroadcastKernel<double>::CanBeUsed(const int64_t& attr) const {
   return true;
 }
 
 template <>
-bool VSigmoidKernel<float>::UseMe(const int& d) const {
+bool VSigmoidKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VTanhKernel<float>::UseMe(const int& d) const {
+bool VTanhKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
+bool SeqPoolKernel<float>::CanBeUsed(const seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
+bool SeqPoolKernel<double>::CanBeUsed(const seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
+bool EmbSeqPoolKernel<float>::CanBeUsed(const emb_seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
+bool EmbSeqPoolKernel<double>::CanBeUsed(
+    const emb_seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const {
+bool SgdKernel<float>::CanBeUsed(const sgd_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const {
+bool SgdKernel<double>::CanBeUsed(const sgd_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
+bool MatMulKernel<float>::CanBeUsed(const matmul_attr_t& attr) const {
   return platform::MayIUse(platform::avx);
 }
 
 template <>
-bool MatMulKernel<double>::UseMe(const matmul_attr_t& attr) const {
+bool MatMulKernel<double>::CanBeUsed(const matmul_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SoftmaxKernel<float>::UseMe(const int& d) const {
+bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
   // tuned on avx2
   return platform::MayIUse(platform::avx) && d < 60;
 }
 
-#define AWALYS_USE_ME_WITH_DOUBLE(func)                  \
-  template <>                                            \
-  bool func##Kernel<double>::UseMe(const int& d) const { \
-    return true;                                         \
+#define AWALYS_USE_ME_WITH_DOUBLE(func)                       \
+  template <>                                                 \
+  bool func##Kernel<double>::CanBeUsed(const int& d) const {  \
+    return true;                                              \
   }
 
 AWALYS_USE_ME_WITH_DOUBLE(VMul);
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h
index 8c1d8b57e0c..f51dca654cd 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.h
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.h
@@ -175,13 +175,13 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
   }
 }
 
-#define DECLARE_MKL_KERNEL(name)                                           \
-  template <typename T>                                                    \
-  class name##Kernel : public KernelMore<name##Tuple<T>> {                 \
-   public:                                                                 \
-    name##Kernel() { this->func = name<T>; }                               \
-    bool UseMe(const typename name##Tuple<T>::attr_type&) const override;  \
-    const char* ImplType() const override { return "MKL"; }                \
+#define DECLARE_MKL_KERNEL(name)                                              \
+  template <typename T>                                                       \
+  class name##Kernel : public KernelMore<name##Tuple<T>> {                    \
+   public:                                                                    \
+    name##Kernel() { this->func = name<T>; }                                  \
+    bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
+    const char* ImplType() const override { return "MKL"; }                   \
   }
 
 // ABCMNK
diff --git a/paddle/fluid/operators/jit/registry.h b/paddle/fluid/operators/jit/registry.h
index cb32c487208..c8da92c0c53 100644
--- a/paddle/fluid/operators/jit/registry.h
+++ b/paddle/fluid/operators/jit/registry.h
@@ -49,8 +49,8 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
   void operator()(KernelType kt) const {
     KernelKey kkey(kt, PlaceType());
-    Pool().Instance().Insert(kkey,
-                             std::move(make_unique<const KERNEL_IMPL_TYPE>()));
+    Pool::Instance().Insert(kkey,
+                            std::move(make_unique<const KERNEL_IMPL_TYPE>()));
     constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
     JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index a574bf2079f..898133a03b5 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include
+#include
 #include
 #include
 #include
@@ -68,31 +69,11 @@ template <typename KernelTuple, typename PlaceType, typename Tester,
           typename... Args>
 void TestAllImpls(const typename KernelTuple::attr_type& attr,
                   const Tester& verifier, const Args&... args) {
-  // test jitcode
-  auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
-  if (jitcode) {
-    VLOG(10) << "Test Jitcode Kernel ";
-    verifier(jitcode, args...);
+  auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  for (auto f : funcs) {
+    VLOG(10) << "Test Kernel " << f.first;
+    verifier(f.second, args...);
   }
-  // test all impls in more
-  jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
-  auto& pool = jit::KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        auto more = i->GetFunc();
-        VLOG(10) << "Test More Kernel : " << i->ImplType();
-        verifier(more, args...);
-      }
-    }
-  }
-  // test result from Get function
-  VLOG(10) << "Test final get function ";
-  auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
-  verifier(tgt, args...);
 }
 
 template <typename KernelTuple, typename PlaceType>
@@ -100,7 +81,7 @@ void TestKernelXYZN() {
   using T = typename KernelTuple::data_type;
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     std::vector<T> x(d), y(d), zref(d);
@@ -159,7 +140,7 @@ void TestKernelAXYN() {
   using T = typename KernelTuple::data_type;
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     const T a = static_cast<T>(3);
@@ -202,7 +183,7 @@ void TestKernelXYN() {
   using T = typename KernelTuple::data_type;
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     std::vector<T> x(d), yref(d);
@@ -245,7 +226,7 @@ void TestKernelXRN() {
   auto last_acc = FLAGS_acc;
   FLAGS_acc = 1e-4;
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
     std::vector<T> x(d);
     RandomVec<T>(d, x.data());
@@ -279,7 +260,7 @@ void TestKernelLSTM() {
           const jit::lstm_attr_t attr(
               d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand),
               jit::to_kerneltype(act_cell), use_peephole);
-          auto ref = jit::GetRefer<KernelTuple>();
+          auto ref = jit::GetReferFunc<KernelTuple>();
          EXPECT_TRUE(ref != nullptr);
           std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
           std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
@@ -370,7 +351,7 @@ void TestKernelGRU() {
       for (auto& act_cand : all_acts) {
         const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
                                    jit::to_kerneltype(act_cand));
-        auto ref = jit::GetRefer<KernelTuple>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
         RandomVec<T>(3 * d, xsrc.data());
@@ -423,7 +404,7 @@ void TestKernelNCHW16CMulNC() {
   using T = typename KernelTuple::data_type;
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   const int n = 3, c = 16 * 4, h = 10, w = 10;
-  auto ref = jit::GetRefer<KernelTuple>();
+  auto ref = jit::GetReferFunc<KernelTuple>();
   EXPECT_TRUE(ref != nullptr);
   int sz = n * c * h * w;
   std::vector<T> x(sz), y(n * c), zref(sz);
@@ -439,7 +420,9 @@ void TestKernelNCHW16CMulNC() {
   constexpr int simd_width = ZMM_FLOAT_BLOCK;
   int C = c / simd_width;
   auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(0);
-  auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(0);
+  auto funcs = jit::GetAllCandidateFuncs<KernelTuple, PlaceType>(0);
+  EXPECT_GT(funcs.size(), 0UL);
+  auto jitcode = funcs[0];
   EXPECT_TRUE(tgt != nullptr);
 
   if (std::is_same<T, float>::value &&
@@ -482,7 +465,7 @@ void TestKernelLayerNorm() {
     int left = n * x_dim_0;
     for (int x_dim_1 : TestSizes()) {
       int right = x_dim_1;
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       int sz = left * right;
       std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right),
@@ -555,7 +538,7 @@ void TestKernelCRFDecoding() {
   test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000));
   for (int seq_len : {1, 11, 17, 50}) {
     for (int tag_num : test_sizes) {
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       int x_sz = seq_len * tag_num;
       int w_sz = (tag_num + state_trans_base_idx) * tag_num;
@@ -606,7 +589,7 @@ void TestKernelSeqPool() {
     jit::seq_pool_attr_t attr(w, type);
     for (int h : test_sizes) {
       attr.h = h;
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(h * w), yref(w);
       RandomVec<T>(h * w, x.data());
@@ -649,7 +632,7 @@ void TestKernelEmbSeqPool() {
   for (auto type : pool_types) {
     for (int idx_w : {1, 2, 10, 16}) {
      for (int idx_h : {1, 2, 9, 13, 16}) {
-        auto ref = jit::GetRefer<KernelTuple>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<int64_t> idx(idx_h * idx_w);
         RandomVec<int64_t>(idx_h * idx_w, idx.data(), 0, tbl_h - 1);
@@ -701,7 +684,7 @@ void TestKernelMatMul() {
   for (int m : {1, 2, 3, 4}) {
     for (int n : {1, 2, 3, 4}) {
       for (int k : TestSizes()) {
-        auto ref = jit::GetRefer<KernelTuple>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> a(m * k), b(k * n), c(m * n);
         RandomVec<T>(m * k, a.data());
@@ -740,7 +723,7 @@ void TestKernelSoftmax() {
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int bs : {1, 2, 10}) {
     for (int n : TestSizes()) {
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(bs * n), y(bs * n);
       RandomVec<T>(bs * n, x.data());
@@ -808,7 +791,7 @@ void TestKernelSgd() {
         RandomVec<T>(rows_size * grad_w, grad.data());
         const int64_t* rows_data = rows.data();
         const T* grad_data = grad.data();
-        auto ref = jit::GetRefer<KernelTuple>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
         ref(&lr, param_data, grad_data, rows_data, out_data, &attr);
@@ -874,7 +857,7 @@ void TestKernelVBroadcast() {
     RandomVec<T>(w, x.data());
     const T* x_data = x.data();
     for (int64_t h : {1, 2, 6}) {
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> y(w * h);
       T* y_data = y.data();
@@ -900,6 +883,135 @@ void TestKernelVBroadcast() {
   }
 }
 
+// test pool
+TEST(JITKernel_pool, jitcreator) {
+  const auto& jitcreators = jit::JitCodeCreatorPool::Instance().AllCreators();
+  EXPECT_EQ(jitcreators.size(), 25UL);
+}
+
+TEST(JITKernel_pool, jitpool) {
+  // jitpool is related to attr
+  const auto& kers = jit::JitCodePool<jit::kVAdd>().Instance().AllKernels();
+  EXPECT_EQ(kers.size(), 0UL);
+  jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3);
+  // after calling GetAllCandidateKernels, the jitcode is created automatically
+  EXPECT_EQ(kers.size(), 1UL);
+}
+
+TEST(JITKernel_pool, more) {
+  const auto& kers = jit::KernelPool::Instance().AllKernels();
+  EXPECT_EQ(kers.size(), 21UL);
+}
+
+TEST(JITKernel_pool, refer) {
+  const auto& kers = jit::ReferKernelPool::Instance().AllKernels();
+  EXPECT_EQ(kers.size(), 29UL);
+}
+
+// test helper
+TEST(JITKernel_helper, GetAllCandidateKernels) {
+  auto fp_kers =
+      jit::GetAllCandidateKernels<jit::VExpTuple<float>, CPUPlace>(10);
+#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
+  EXPECT_GE(fp_kers.size(), 1UL);  // refer
+#else
+  EXPECT_GE(fp_kers.size(), 3UL);  // jitcode, mkl, refer
+#endif
+
+  auto db_kers =
+      jit::GetAllCandidateKernels<jit::VExpTuple<double>, CPUPlace>(10);
+#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
+  EXPECT_GE(db_kers.size(), 1UL);  // refer
+#else
+  EXPECT_GE(db_kers.size(), 2UL);  // mkl, refer
+#endif
+}
+
+TEST(JITKernel_helper, GetAllCandidateFuncsWithTypes) {
+  auto fp_kers =
+      jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<float>, CPUPlace>(10);
+  EXPECT_GE(fp_kers.size(), 3UL);  // jitcode, mkl, refer
+
+  auto db_kers =
+      jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<double>, CPUPlace>(10);
+  EXPECT_GE(db_kers.size(), 2UL);  // mkl, refer
+}
+
+TEST(JITKernel_helper, GetAllCandidateFuncs) {
+  auto funcs = jit::GetAllCandidateFuncs<jit::VExpTuple<float>, CPUPlace>(10);
+  auto kers = jit::GetAllCandidateKernels<jit::VExpTuple<float>, CPUPlace>(10);
+  EXPECT_EQ(funcs.size(), kers.size());
+
+  std::vector<float> x(10), tgt(10);
+  RandomVec<float>(10, x.data());
+  auto best = jit::GetDefaultBestFunc<jit::VExpTuple<float>, CPUPlace>(10);
+  best(x.data(), tgt.data(), 10);
+  for (auto f : funcs) {
+    std::vector<float> y(10);
+    f(x.data(), y.data(), 10);
+    ExpectEQ<float>(y.data(), tgt.data(), 10);
+  }
+}
+
+TEST(JITKernel_helper, attr) {
+  std::ostringstream out;
+
+  // KernelTypes
+  out << jit::to_string(jit::kNone) << jit::to_string(jit::kCRFDecoding)
+      << jit::to_string(jit::kEmbSeqPool) << jit::to_string(jit::kGRUH1)
+      << jit::to_string(jit::kGRUHtPart1) << jit::to_string(jit::kGRUHtPart2)
+      << jit::to_string(jit::kHSum) << jit::to_string(jit::kHMax)
+      << jit::to_string(jit::kLSTMCtHt) << jit::to_string(jit::kLSTMC1H1)
+      << jit::to_string(jit::kLayerNorm) << jit::to_string(jit::kMatMul)
+      << jit::to_string(jit::kNCHW16CMulNC) << jit::to_string(jit::kSeqPool)
+      << jit::to_string(jit::kSoftmax) << jit::to_string(jit::kVAdd)
+      << jit::to_string(jit::kVAddBias) << jit::to_string(jit::kVAddRelu)
+      << jit::to_string(jit::kVBroadcast) << jit::to_string(jit::kVCopy)
+      << jit::to_string(jit::kVExp) << jit::to_string(jit::kVIdentity)
+      << jit::to_string(jit::kVMul) << jit::to_string(jit::kVRelu)
+      << jit::to_string(jit::kVScal) << jit::to_string(jit::kSgd)
+      << jit::to_string(jit::kVSigmoid) << jit::to_string(jit::kVSquare)
+      << jit::to_string(jit::kVSub) << jit::to_string(jit::kVTanh);
+  EXPECT_EQ(out.str().size(), 234);
+
+  // SeqPoolTypes
+  out.str("");
+  out << jit::to_string(jit::kSum) << jit::to_string(jit::kAvg)
+      << jit::to_string(jit::kSqrt);
+  EXPECT_EQ(out.str().size(), 13);
+
+  EXPECT_EQ(jit::to_kerneltype("relu"), jit::kVRelu);
+  EXPECT_EQ(jit::to_kerneltype("Identity"), jit::kVIdentity);
+  EXPECT_EQ(jit::to_kerneltype("VEXP"), jit::kVExp);
+  EXPECT_EQ(jit::to_kerneltype("SigmoiD"), jit::kVSigmoid);
+  EXPECT_EQ(jit::to_kerneltype("VTanh"), jit::kVTanh);
+
+  out.str("");
+  out << jit::lstm_attr_t(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
+  EXPECT_EQ(out.str().size(), 89);
+
+  out.str("");
+  out << jit::gru_attr_t(8, jit::kVIdentity, jit::kVSigmoid);
+  EXPECT_EQ(out.str().size(), 52);
+
+  out.str("");
+  out << jit::seq_pool_attr_t(8, jit::SeqPoolType::kSum);
+  EXPECT_EQ(out.str().size(), 44);
+
+  out.str("");
+  out << jit::emb_seq_pool_attr_t(1, 2, 3, 4, 5, jit::SeqPoolType::kAvg);
+  EXPECT_EQ(out.str().size(), 93);
+
+  out.str("");
+  out << jit::sgd_attr_t(1, 2, 3, 4, 5);
+  EXPECT_EQ(out.str().size(), 81);
+
+  out.str("");
+  out << jit::matmul_attr_t(1, 2, 3);
+  EXPECT_EQ(out.str().size(), 14);
+}
+
+// test kernels
 #define TestKernelVMul TestKernelXYZN
 #define TestKernelVAdd TestKernelXYZN
 #define TestKernelVAddRelu TestKernelXYZN
@@ -969,6 +1081,14 @@ TEST_CPU_KERNEL(Softmax);
 TEST_CPU_KERNEL(Sgd);
 TEST_CPU_KERNEL(VBroadcast);
 
+TEST(JITKernel, kernel_func) {
+  auto f1 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(3);
+  auto f2 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[3];
+  EXPECT_TRUE(f1 != nullptr);
+  EXPECT_TRUE(f1 == f2);
+  // TODO(TJ): check not equal
+}
+
 TEST(JITKernel_key, lstm) {
   jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
   jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
@@ -1000,11 +1120,3 @@ TEST(JITKernel_key, gru) {
   EXPECT_TRUE(key2 == key3);
   EXPECT_TRUE(key3 != key4);
 }
-
-TEST(JITKernel, kernel_func) {
-  auto f1 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(3);
-  auto f2 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[3];
-  EXPECT_TRUE(f1 != nullptr);
-  EXPECT_TRUE(f1 == f2);
-  // TODO(TJ): check not equal
-}
-- 
GitLab
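
A minimal usage sketch of the refactored helper API, distilled from the unit tests added in test.cc above. It assumes the aggregated header paddle/fluid/operators/jit/kernels.h; the float VAdd tuple is the one the tests themselves use, while the size d = 8 and the function name VAddDemo are illustrative choices, not part of the patch.

// Sketch only: exercises GetAllCandidateFuncsWithTypes and the KernelFuncs
// cache introduced by this patch. VAddTuple<float> and d = 8 are assumed
// for illustration.
#include <vector>

#include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/operators/jit/kernels.h"

namespace jit = paddle::operators::jit;
using CPUPlace = paddle::platform::CPUPlace;

void VAddDemo() {
  const int d = 8;
  std::vector<float> x(d, 1.f), y(d, 2.f), z(d);

  // Every usable implementation for this attribute, paired with its
  // ImplType() tag ("JitCode", "MKL", "Refer", ...), in the search order
  // jitcode > more > refer.
  auto funcs =
      jit::GetAllCandidateFuncsWithTypes<jit::VAddTuple<float>, CPUPlace>(d);
  for (auto& f : funcs) {
    f.second(x.data(), y.data(), z.data(), d);  // z[i] = x[i] + y[i]
  }

  // The cached default-best function: the first candidate in that order,
  // so jitcode wins whenever its CanBeUsed(attr) holds.
  auto best = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(d);
  best(x.data(), y.data(), z.data(), d);
}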