Commit 45bdd84d authored by tensor-tang

enhance the jitkernel helper and add unit tests

test=develop
Parent 14a764c9
@@ -111,33 +111,11 @@ template <typename KernelTuple, typename PlaceType, typename... Args>
 void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
   BenchFunc<KernelTuple, Args...> benchmark;
   std::vector<std::pair<std::string, double>> infos;
-  // test refer
-  auto refer = jit::GetRefer<KernelTuple>();
-  if (!refer) {
-    LOG(FATAL) << "Refer can not be empty!";
-  }
-  infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
-  // test jitcode
-  auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
-  if (jitcode) {
-    infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
-  }
-  // test all impls in more
-  jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
-  auto& pool = jit::KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        auto more = i->GetFunc();
-        infos.push_back(
-            std::make_pair(i->ImplType(), benchmark(more, args...)));
-      }
-    }
-  }
+  auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  for (auto f : funcs) {
+    infos.push_back(std::make_pair(f.first, benchmark(f.second, args...)));
+  }
   // Test result from Get function
   auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
   if (!tgt) {
......
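The benchmark loop above now iterates over the (ImplType, function) pairs returned by jit::GetAllCandidateFuncsWithTypes instead of probing refer, jitcode and the "more" pool by hand. A minimal sketch of driving that same helper with a plain chrono timer, assuming the usual jit headers and an XYZN-style kernel such as VAdd (the timer is illustrative, not the BenchFunc used by this file):

#include <chrono>
#include <iostream>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"  // assumed header path

void BenchVAddCandidates() {
  namespace jit = paddle::operators::jit;
  constexpr int d = 256;
  std::vector<float> x(d, 1.f), y(d, 2.f), z(d);
  auto funcs = jit::GetAllCandidateFuncsWithTypes<jit::VAddTuple<float>,
                                                  paddle::platform::CPUPlace>(d);
  for (auto& f : funcs) {
    auto begin = std::chrono::steady_clock::now();
    for (int i = 0; i < 1000; ++i) {
      f.second(x.data(), y.data(), z.data(), d);  // VAdd signature: x, y, z, n
    }
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - begin)
                  .count();
    std::cout << f.first << ": " << us / 1000.0 << " us per call" << std::endl;
  }
}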
@@ -81,7 +81,7 @@ void VActJitCode::genCode() {
 #define DECLARE_ACT_CREATOR(name) \
   class name##Creator : public JitCodeCreator<int> { \
    public: \
-    bool UseMe(const int& attr) const override; \
+    bool CanBeUsed(const int& attr) const override; \
     size_t CodeSize(const int& d) const override; \
     std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
       return make_unique<name##JitCode>(attr, CodeSize(attr)); \
@@ -96,27 +96,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
 DECLARE_ACT_CREATOR(VTanh);
 
 // TODO(TJ): tuning use me
-bool VReluCreator::UseMe(const int& d) const {
+bool VReluCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VSquareCreator::UseMe(const int& d) const {
+bool VSquareCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VIdentityCreator::UseMe(const int& d) const {
+bool VIdentityCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VExpCreator::UseMe(const int& d) const {
+bool VExpCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d < 32;
 }
 
-bool VSigmoidCreator::UseMe(const int& d) const {
+bool VSigmoidCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VTanhCreator::UseMe(const int& d) const {
+bool VTanhCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
......
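For reference, after this rename DECLARE_ACT_CREATOR(VRelu) expands to roughly the class below (a sketch reconstructed from the macro in the hunk above, with explanatory comments added; it is not copied from the generated source):

class VReluCreator : public JitCodeCreator<int> {
 public:
  // Whether a JIT kernel can be generated for this attribute (the vector width).
  bool CanBeUsed(const int& attr) const override;
  // Upper bound of the generated code size, used to reserve the Xbyak buffer.
  size_t CodeSize(const int& d) const override;
  // Emit the generated kernel for the given attribute.
  std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
    return make_unique<VReluJitCode>(attr, CodeSize(attr));
  }
};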
@@ -142,7 +142,7 @@ void NCHW16CMulNCJitCode::genCode() {
 class NCHW16CMulNCCreator : public JitCodeCreator<int> {
  public:
-  bool UseMe(const int& attr) const override {
+  bool CanBeUsed(const int& attr) const override {
     return platform::MayIUse(platform::avx512f);
   }
   size_t CodeSize(const int& d) const override { return 256 * 1024; }
@@ -154,7 +154,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
 #define DECLARE_BLAS_CREATOR(name) \
   class name##Creator : public JitCodeCreator<int> { \
    public: \
-    bool UseMe(const int& attr) const override { \
+    bool CanBeUsed(const int& attr) const override { \
       return platform::MayIUse(platform::avx) && attr <= 1024; \
     } \
     size_t CodeSize(const int& d) const override { \
......
@@ -121,7 +121,7 @@ void EmbSeqPoolJitCode::genCode() {
 class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
  public:
-  bool UseMe(const emb_seq_pool_attr_t& attr) const override {
+  bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override {
     return platform::MayIUse(platform::avx) &&
            attr.table_width % YMM_FLOAT_BLOCK == 0;
   }
......
@@ -86,7 +86,7 @@ void GRUJitCode::genCode() {
   class name##Creator : public JitCodeCreator<gru_attr_t> { \
    public: \
     /* TODO(TJ): enable more */ \
-    bool UseMe(const gru_attr_t& attr) const override { \
+    bool CanBeUsed(const gru_attr_t& attr) const override { \
       return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
     } \
     size_t CodeSize(const gru_attr_t& attr) const override { \
......
@@ -76,7 +76,7 @@ void HOPVJitCode::genCode() {
 #define DECLARE_HOP_CREATOR(name) \
   class name##Creator : public JitCodeCreator<int> { \
    public: \
-    bool UseMe(const int& attr) const override { \
+    bool CanBeUsed(const int& attr) const override { \
       return platform::MayIUse(platform::avx); \
     } \
     size_t CodeSize(const int& d) const override { \
......
@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
   virtual void genCode() = 0;
 
   size_t getSize() const override { return CodeGenerator::getSize(); }
-  const unsigned char* getCodeInternal() override {
+  const unsigned char* getCodeInternal() const override {
     const Xbyak::uint8* code = CodeGenerator::getCode();
     return code;
   }
......
@@ -114,7 +114,7 @@ void LSTMJitCode::genCode() {
   class name##Creator : public JitCodeCreator<lstm_attr_t> { \
    public: \
     /* TODO(TJ): enable more */ \
-    bool UseMe(const lstm_attr_t& attr) const override { \
+    bool CanBeUsed(const lstm_attr_t& attr) const override { \
       return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
     } \
     size_t CodeSize(const lstm_attr_t& attr) const override { \
......
@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
 class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
  public:
-  bool UseMe(const matmul_attr_t& attr) const override {
+  bool CanBeUsed(const matmul_attr_t& attr) const override {
     return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
            attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
   }
......
@@ -57,7 +57,7 @@ void SeqPoolJitCode::genCode() {
 class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
  public:
-  bool UseMe(const seq_pool_attr_t& attr) const override {
+  bool CanBeUsed(const seq_pool_attr_t& attr) const override {
     return platform::MayIUse(platform::avx);
   }
   size_t CodeSize(const seq_pool_attr_t& attr) const override {
......
@@ -104,7 +104,7 @@ void SgdJitCode::genCode() {
 class SgdCreator : public JitCodeCreator<sgd_attr_t> {
  public:
-  bool UseMe(const sgd_attr_t& attr) const override {
+  bool CanBeUsed(const sgd_attr_t& attr) const override {
     return platform::MayIUse(platform::avx) &&
            attr.grad_width % YMM_FLOAT_BLOCK == 0;
   }
......
@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
 class VBroadcastCreator : public JitCodeCreator<int64_t> {
  public:
-  bool UseMe(const int64_t& w) const override {
+  bool CanBeUsed(const int64_t& w) const override {
     return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
   }
   size_t CodeSize(const int64_t& w) const override {
......
@@ -31,7 +31,7 @@ namespace paddle {
 namespace operators {
 namespace jit {
 
-// refer do not need useme, it would be the last one.
+// refer do not need CanBeUsed, it would be the last one.
 void GenBase::dumpCode(const unsigned char* code) const {
   if (code) {
     static int counter = 0;
......
@@ -31,9 +31,10 @@ class GenBase : public Kernel {
   virtual ~GenBase() = default;
   virtual std::string name() const = 0;
   virtual size_t getSize() const = 0;
-  virtual const unsigned char* getCodeInternal() = 0;
+  virtual const unsigned char* getCodeInternal() const = 0;
+  const char* ImplType() const override { return "JitCode"; }
   template <typename Func>
-  Func getCode() {
+  Func getCode() const {
     const unsigned char* code = this->getCodeInternal();
     if (FLAGS_dump_jitcode) {
       this->dumpCode(code);
@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
   virtual ~JitCodeCreator() = default;
 
   // condition when this jit code can be used.
-  virtual bool UseMe(const Attr& attr) const = 0;
+  virtual bool CanBeUsed(const Attr& attr) const = 0;
   // estimate this code size
   virtual size_t CodeSize(const Attr& attr) const = 0;
......
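Making getCodeInternal() and getCode() const is what allows a typed function pointer to be pulled out of the const Kernel* candidates that the new helper functions hand back. A hedged sketch of that extraction step (header paths and the VAddTuple choice are assumptions for illustration):

#include <string>
#include "paddle/fluid/operators/jit/gen_base.h"  // assumed path for GenBase
#include "paddle/fluid/operators/jit/kernels.h"   // assumed path for the tuples

namespace jit = paddle::operators::jit;

// Returns the typed function pointer if the candidate is JIT-generated code,
// otherwise nullptr (non-JitCode candidates go through KernelMore::GetFunc()).
jit::VAddTuple<float>::func_type ExtractJitFunc(const jit::Kernel* k) {
  if (k == nullptr || std::string(k->ImplType()) != "JitCode") {
    return nullptr;
  }
  auto gen = dynamic_cast<const jit::GenBase*>(k);
  // getCode() is const after this change, so it can be called on a const GenBase*.
  return gen ? gen->getCode<jit::VAddTuple<float>::func_type>() : nullptr;
}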
@@ -14,9 +14,6 @@
 
 #pragma once
 
-extern "C" {
-#include <xxhash.h>
-}
 #include <iostream>
 #include <string>
 #include <unordered_map>
@@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
 inline typename std::enable_if<
     std::is_same<typename KernelTuple::data_type, float>::value &&
         std::is_same<PlaceType, platform::CPUPlace>::value,
-    typename KernelTuple::func_type>::type
+    const Kernel*>::type
 GetJitCode(const typename KernelTuple::attr_type& attr) {
-  using Func = typename KernelTuple::func_type;
   using Attr = typename KernelTuple::attr_type;
   size_t key = JitCodeKey<Attr>(attr);
-  auto& codes = JitCodePool<KernelTuple::kernel_type>().Instance();
+  auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
   if (codes.Has(key)) {
-    return codes.AllKernels().at(key)->template getCode<Func>();
+    return codes.AllKernels().at(key).get();
   }
 
   // creator is not related with attr, so can use KernelKey as key
   KernelKey kkey(KernelTuple::kernel_type, PlaceType());
   // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
-  auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
+  auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
   auto iter = creator_map.find(kkey);
   if (iter != creator_map.end()) {
     auto& creators = iter->second;
     for (auto& cur : creators) {
       auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
-      if (i && i->UseMe(attr)) {
+      if (i && i->CanBeUsed(attr)) {
         auto p = i->CreateJitCode(attr);
         if (p) {
-          auto f = p->template getCode<Func>();
+          auto res = p.get();
           codes.Insert(key, std::move(p));
-          return f;
+          return res;
         }
       }
     }
@@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
 inline typename std::enable_if<
     !std::is_same<typename KernelTuple::data_type, float>::value ||
         !std::is_same<PlaceType, platform::CPUPlace>::value,
-    typename KernelTuple::func_type>::type
+    const Kernel*>::type
 GetJitCode(const typename KernelTuple::attr_type& attr) {
   return nullptr;
 }
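GetJitCode() now hands back a const Kernel* instead of a bare function pointer: the generated code stays owned by the per-kernel-type JitCodePool, keyed by JitCodeKey(attr), and only a raw pointer escapes. A sketch of the caching behaviour this implies, modeled on the JITKernel_pool.jitpool test added further down (the +1 expectation only holds on a CPU where a VAdd jitcode creator reports CanBeUsed):

namespace jit = paddle::operators::jit;
using CPUPlace = paddle::platform::CPUPlace;

void JitCodeCacheSketch() {
  const auto& codes = jit::JitCodePool<jit::kVAdd>().Instance().AllKernels();
  size_t before = codes.size();
  // First call may generate and cache jitcode for attr = 3.
  jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3);
  // Second call with the same attribute hits the JitCodePool cache.
  jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3);
  size_t after = codes.size();
  // Expected: after == before + 1, not before + 2.
  (void)before;
  (void)after;
}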
@@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
 // Refer code do not related with attr, which is just for cast
 // Refer is always on CPUPlace
 template <typename KernelTuple>
-inline typename KernelTuple::func_type GetRefer() {
-  auto& ref_pool = ReferKernelPool().Instance().AllKernels();
+inline const Kernel* GetReferKernel() {
+  auto& ref_pool = ReferKernelPool::Instance().AllKernels();
   KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
   auto ref_iter = ref_pool.find(kkey);
   PADDLE_ENFORCE(ref_iter != ref_pool.end(),
@@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
   auto& ref_impls = ref_iter->second;
   for (auto& impl : ref_impls) {
     auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
     if (i) {
-      return i->GetFunc();
+      return i;
     }
   }
   return nullptr;
 }
 
-template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
-typename KernelTuple::func_type Get(
-    const typename KernelTuple::attr_type& attr) {
-  auto jitfunc = GetJitCode<KernelTuple, PlaceType>(attr);
-  if (jitfunc) {
-    return jitfunc;
-  }
-  // pool: (KernelKey(type, place), vector<KernelPtr>)
-  KernelKey kkey(KernelTuple::kernel_type, PlaceType());
-  auto& pool = KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        return i->GetFunc();
-      }
-    }
-  }
-  // The last implementation should be reference function on CPUPlace.
-  return GetRefer<KernelTuple>();
+template <typename KernelTuple>
+inline typename KernelTuple::func_type GetReferFunc() {
+  auto ker = GetReferKernel<KernelTuple>();
+  auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
+  PADDLE_ENFORCE(p, "The Refer kernel should exsit");
+  return p->GetFunc();
+}
+
+// Return all Kernels that can be used
+template <typename KernelTuple, typename PlaceType>
+std::vector<const Kernel*> GetAllCandidateKernels(
+    const typename KernelTuple::attr_type& attr) {
+  // the search order shoudl be jitcode > more > refer
+  std::vector<const Kernel*> res;
+  auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
+  if (jitker) {
+    res.emplace_back(jitker);
+  }
+
+  // more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
+  KernelKey kkey(KernelTuple::kernel_type, PlaceType());
+  auto& pool = KernelPool::Instance().AllKernels();
+  auto iter = pool.find(kkey);
+  if (iter != pool.end()) {
+    auto& impls = iter->second;
+    for (auto& impl : impls) {
+      auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
+      if (i && i->CanBeUsed(attr)) {
+        res.emplace_back(i);
+      }
+    }
+  }
+
+  // The last implementation should be reference function on CPUPlace.
+  auto ref = GetReferKernel<KernelTuple>();
+  PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
+  res.emplace_back(ref);
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+std::vector<std::pair<std::string, typename KernelTuple::func_type>>
+GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
+  using Func = typename KernelTuple::func_type;
+  auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
+  std::vector<std::pair<std::string, Func>> res;
+  for (auto k : kers) {
+    std::string name = k->ImplType();
+    if (name == "JitCode") {
+      auto i = dynamic_cast<const GenBase*>(k);
+      PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
+      res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
+    } else {
+      auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
+      PADDLE_ENFORCE(i, "kernel cast can not fail.");
+      res.emplace_back(std::make_pair(name, i->GetFunc()));
+    }
+  }
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
+    const typename KernelTuple::attr_type& attr) {
+  auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  std::vector<typename KernelTuple::func_type> res;
+  for (auto& i : funcs) {
+    res.emplace_back(i.second);
+  }
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+typename KernelTuple::func_type GetDefaultBestFunc(
+    const typename KernelTuple::attr_type& attr) {
+  auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
+  PADDLE_ENFORCE_GE(funcs.size(), 1UL);
+  // Here could do some runtime benchmark of this attr and return the best one.
+  // But yet just get the first one as the default best one,
+  // which is searched in order and tuned by offline.
+  return funcs[0];
 }
 
 template <typename KernelTuple, typename PlaceType>
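As a concrete illustration of the search order (jitcode, then the "more" implementations, then refer): for VExpTuple<float> with d = 10 on a Linux AVX machine with MKL enabled, the pairs are expected to come back roughly as JitCode, MKL, Refer, which is what the new GetAllCandidateFuncsWithTypes unit test further down asserts. A small sketch that prints the candidate names:

#include <iostream>
#include "paddle/fluid/operators/jit/kernels.h"  // assumed header path

void PrintVExpCandidates() {
  namespace jit = paddle::operators::jit;
  auto funcs = jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<float>,
                                                  paddle::platform::CPUPlace>(10);
  for (auto& f : funcs) {
    std::cout << f.first << std::endl;  // e.g. "JitCode", "MKL", "Refer"
  }
}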
@@ -134,17 +187,13 @@ class KernelFuncs {
   // the exposed interface to use
   typename KernelTuple::func_type At(
       const typename KernelTuple::attr_type& attr) {
-    // XXH64: 13.8 GB/s
-    // TODO(TJ): change me, maybe not all attr change need one key, should be
-    // attrkey
-    int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0);
+    // Maybe here is not good enough, not all kernels should have jitcode
+    int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
     if (Has(key)) {
       return funcs_.at(key);
     }
-    // If do not have this attr in cache,
-    // then could run some runtime benchmark of this attr and save the best one.
-    // Here just get the offline benchmarked best one.
-    auto func = Get<KernelTuple, PlaceType>(attr);
+    // If do not have this attr in cache then get the default best
+    auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
     Insert(key, func);
     return func;
   }
@@ -156,7 +205,6 @@ class KernelFuncs {
  protected:
   bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
-
   void Insert(int64_t key, typename KernelTuple::func_type func) {
     funcs_.emplace(key, func);
   }
......
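KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr) is the cached front end over GetDefaultBestFunc: the first call for a given JitCodeKey pays the search cost, later calls return the memoized function pointer. A usage sketch mirroring the JITKernel.kernel_func test added at the end of this change:

namespace jit = paddle::operators::jit;
using CPUPlace = paddle::platform::CPUPlace;

void CachedVAdd() {
  float x[8] = {1.f}, y[8] = {2.f}, z[8] = {0.f};
  // First call for d = 8 searches jitcode/more/refer and caches the winner.
  auto vadd = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(8);
  vadd(x, y, z, 8);
  // Same attribute, same cached function; operator[] is shorthand for At().
  auto same = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[8];
  // vadd == same is exactly what the new kernel_func test asserts.
  (void)same;
}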
@@ -302,6 +302,7 @@ class Kernel {
  public:
   Kernel() = default;
   virtual ~Kernel() = default;
+  virtual const char* ImplType() const = 0;
   DISABLE_COPY_AND_ASSIGN(Kernel);
 };
@@ -312,8 +313,8 @@ class KernelMore : public Kernel {
   using Func = typename KernelTuple::func_type;
   using Attr = typename KernelTuple::attr_type;
   virtual Func GetFunc() const { return func; }
-  virtual bool UseMe(const Attr& attr) const = 0;
-  virtual const char* ImplType() const = 0;
+  // specify this kernel can be used, means it should not fail if use it.
+  virtual bool CanBeUsed(const Attr& attr) const = 0;
 
  protected:
   Func func{nullptr};
@@ -323,7 +324,7 @@ template <typename KernelTuple>
 class ReferKernel : public KernelMore<KernelTuple> {
  public:
   // Refer code can always be used
-  bool UseMe(const typename KernelTuple::attr_type& attr) const override {
+  bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
     return true;
   }
   const char* ImplType() const override { return "Refer"; }
......
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/kernel_key.h"
+#include <xxhash.h>
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -49,6 +50,8 @@ static inline int act_type_convert(KernelType type) {
 template <>
 size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
+  // XXH64: 13.8 GB/s
   size_t key = attr.d;
   int gate_key = act_type_convert(attr.act_gate) << 1;
   int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
......
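JitCodeKey<lstm_attr_t> packs the dimension and the activation kinds into a single size_t, replacing the XXH64-over-raw-bytes key that KernelFuncs used before. A sketch of the property the JITKernel_key tests below rely on (exact bit layout aside): attributes that differ in size or in any activation must map to different keys, while equal attributes map to the same key:

namespace jit = paddle::operators::jit;

void LstmKeySketch() {
  jit::lstm_attr_t a1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t a2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t a3(9, jit::kVTanh, jit::kVSigmoid, jit::kVTanh);
  size_t k1 = jit::JitCodeKey<jit::lstm_attr_t>(a1);
  size_t k2 = jit::JitCodeKey<jit::lstm_attr_t>(a2);
  size_t k3 = jit::JitCodeKey<jit::lstm_attr_t>(a3);
  // Expected: k1 != k2 (different d) and k2 != k3 (a different activation).
  (void)k1; (void)k2; (void)k3;
}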
@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
   }
 }
 
-bool CRFDecodingKernel::UseMe(const int& d) const {
+bool CRFDecodingKernel::CanBeUsed(const int& d) const {
 #ifdef __AVX512F__
   constexpr int block = ZMM_FLOAT_BLOCK;
 #else
......
@@ -29,7 +29,8 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
 class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
  public:
   CRFDecodingKernel() { this->func = CRFDecoding; }
-  bool UseMe(const typename CRFDecodingTuple<float>::attr_type&) const override;
+  bool CanBeUsed(
+      const typename CRFDecodingTuple<float>::attr_type&) const override;
   const char* ImplType() const override { return "Intrinsic"; }
 };
......
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
   }
 }
 
-bool LayerNormKernel::UseMe(const int& d) const {
+bool LayerNormKernel::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
 }
......
@@ -30,7 +30,8 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
 class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
  public:
   LayerNormKernel() { this->func = LayerNorm; }
-  bool UseMe(const typename LayerNormTuple<float>::attr_type&) const override;
+  bool CanBeUsed(
+      const typename LayerNormTuple<float>::attr_type&) const override;
   const char* ImplType() const override { return "Intrinsic"; }
 };
......
@@ -204,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
 }
 
 // TODO(TJ): tuning me
-bool VSigmoidKernel::UseMe(const int& d) const { return true; }
+bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; }
 
-bool VTanhKernel::UseMe(const int& d) const { return true; }
+bool VTanhKernel::CanBeUsed(const int& d) const { return true; }
 
-bool SoftmaxKernel::UseMe(const int& d) const { return true; }
+bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; }
 
-bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
+bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
 
-bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
+bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
 
-bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUH1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
-bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUHtPart1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
-bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
 }  // namespace mix
 }  // namespace more
......
@@ -34,12 +34,12 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
 void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
 void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
 
 #define DECLARE_MORE_KERNEL(name) \
   class name##Kernel : public KernelMore<name##Tuple<T>> { \
    public: \
     name##Kernel() { this->func = name; } \
-    bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
+    bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
     const char* ImplType() const override { return "Mixed"; } \
   }
 
 // XYN
......
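After the rename, DECLARE_MORE_KERNEL(VSigmoid) in this header expands to roughly the class below (sketch derived from the macro above, with comments added; T is the element type used in this header and the surrounding mix namespace is implied):

class VSigmoidKernel : public KernelMore<VSigmoidTuple<T>> {
 public:
  VSigmoidKernel() { this->func = VSigmoid; }
  // Whether this mixed implementation may be picked for the given size;
  // the definitions shown earlier in this diff currently always return true.
  bool CanBeUsed(const typename VSigmoidTuple<T>::attr_type&) const override;
  const char* ImplType() const override { return "Mixed"; }
};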
@@ -130,105 +130,106 @@ void ASum<double>(const double* x, double* res, int n) {
 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
 template <>
-bool VMulKernel<float>::UseMe(const int& d) const {
+bool VMulKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
-bool VAddKernel<float>::UseMe(const int& d) const {
+bool VAddKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d > 512;
 }
 
 template <>
-bool VScalKernel<float>::UseMe(const int& d) const {
+bool VScalKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
-bool VExpKernel<float>::UseMe(const int& d) const {
+bool VExpKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VSquareKernel<float>::UseMe(const int& d) const {
+bool VSquareKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VCopyKernel<float>::UseMe(const int& d) const {
+bool VCopyKernel<float>::CanBeUsed(const int& d) const {
   return d > 15;
 }
 
 template <>
-bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
+bool VBroadcastKernel<float>::CanBeUsed(const int64_t& d) const {
   return d > 127;
 }
 
 template <>
-bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
+bool VBroadcastKernel<double>::CanBeUsed(const int64_t& attr) const {
   return true;
 }
 
 template <>
-bool VSigmoidKernel<float>::UseMe(const int& d) const {
+bool VSigmoidKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VTanhKernel<float>::UseMe(const int& d) const {
+bool VTanhKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
+bool SeqPoolKernel<float>::CanBeUsed(const seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
+bool SeqPoolKernel<double>::CanBeUsed(const seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
+bool EmbSeqPoolKernel<float>::CanBeUsed(const emb_seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
+bool EmbSeqPoolKernel<double>::CanBeUsed(
+    const emb_seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const {
+bool SgdKernel<float>::CanBeUsed(const sgd_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const {
+bool SgdKernel<double>::CanBeUsed(const sgd_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
+bool MatMulKernel<float>::CanBeUsed(const matmul_attr_t& attr) const {
   return platform::MayIUse(platform::avx);
 }
 
 template <>
-bool MatMulKernel<double>::UseMe(const matmul_attr_t& attr) const {
+bool MatMulKernel<double>::CanBeUsed(const matmul_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SoftmaxKernel<float>::UseMe(const int& d) const {
+bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
   // tuned on avx2
   return platform::MayIUse(platform::avx) && d < 60;
 }
 
 #define AWALYS_USE_ME_WITH_DOUBLE(func) \
   template <> \
-  bool func##Kernel<double>::UseMe(const int& d) const { \
+  bool func##Kernel<double>::CanBeUsed(const int& d) const { \
     return true; \
   }
 
 AWALYS_USE_ME_WITH_DOUBLE(VMul);
......
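These thresholds are what make the ordered search in GetDefaultBestFunc meaningful: for example, the MKL VMul<float> only reports CanBeUsed above d = 512 on AVX-512 hardware, so smaller sizes fall through to jitcode or refer. A sketch of probing which implementation wins for two sizes (the printed names depend on the machine, so they are only suggested in the comment):

#include <iostream>
#include "paddle/fluid/operators/jit/kernels.h"  // assumed header path

void ShowVMulChoice() {
  namespace jit = paddle::operators::jit;
  for (int d : {8, 1024}) {
    auto funcs = jit::GetAllCandidateFuncsWithTypes<jit::VMulTuple<float>,
                                                    paddle::platform::CPUPlace>(d);
    // funcs.front() is what GetDefaultBestFunc would return for this d;
    // e.g. JitCode for d = 8 and possibly MKL for d = 1024 on an AVX-512 machine.
    std::cout << "d=" << d << " -> " << funcs.front().first << std::endl;
  }
}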
@@ -175,13 +175,13 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
   }
 }
 
 #define DECLARE_MKL_KERNEL(name) \
   template <typename T> \
   class name##Kernel : public KernelMore<name##Tuple<T>> { \
    public: \
     name##Kernel() { this->func = name<T>; } \
-    bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
+    bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
     const char* ImplType() const override { return "MKL"; } \
   }
 
 // ABCMNK
......
@@ -49,8 +49,8 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
   void operator()(KernelType kt) const {
     KernelKey kkey(kt, PlaceType());
-    Pool().Instance().Insert(kkey,
-                             std::move(make_unique<const KERNEL_IMPL_TYPE>()));
+    Pool::Instance().Insert(kkey,
+                            std::move(make_unique<const KERNEL_IMPL_TYPE>()));
     constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
     JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
                               KernelImpls...>
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <algorithm>
+#include <iostream>
 #include <random>
 #include <string>
 #include <vector>
@@ -68,31 +69,11 @@ template <typename KernelTuple, typename PlaceType, typename Tester,
           typename... Args>
 void TestAllImpls(const typename KernelTuple::attr_type& attr,
                   const Tester& verifier, const Args&... args) {
-  // test jitcode
-  auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
-  if (jitcode) {
-    VLOG(10) << "Test Jitcode Kernel ";
-    verifier(jitcode, args...);
-  }
-  // test all impls in more
-  jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
-  auto& pool = jit::KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        auto more = i->GetFunc();
-        VLOG(10) << "Test More Kernel : " << i->ImplType();
-        verifier(more, args...);
-      }
-    }
-  }
-  // test result from Get function
-  VLOG(10) << "Test final get function ";
-  auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
-  verifier(tgt, args...);
+  auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  for (auto f : funcs) {
+    VLOG(10) << "Test Kernel " << f.first;
+    verifier(f.second, args...);
+  }
 }
 
 template <typename KernelTuple, typename PlaceType>
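TestAllImpls now receives the verifier as a callable and replays it over every candidate returned by GetAllCandidateFuncsWithTypes. A sketch of what such a verifier looks like for an XYZN kernel, modeled on the test bodies in this file (ExpectEQ, EXPECT_TRUE and the tuple types come from the test itself):

// Sketch: verifier for VAdd-style (x, y, z, n) kernels; it is handed each
// candidate function and re-checks it against the reference output zref.
auto verifier = [](const jit::VAddTuple<float>::func_type tgt,
                   const std::vector<float>& x, const std::vector<float>& y,
                   const std::vector<float>& zref) {
  EXPECT_TRUE(tgt != nullptr);
  std::vector<float> ztgt(zref.size());
  tgt(x.data(), y.data(), ztgt.data(), static_cast<int>(zref.size()));
  ExpectEQ<float>(ztgt.data(), zref.data(), static_cast<int>(zref.size()));
};
// Then: TestAllImpls<jit::VAddTuple<float>, CPUPlace>(d, verifier, x, y, zref);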
@@ -100,7 +81,7 @@ void TestKernelXYZN() {
   using T = typename KernelTuple::data_type;
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     std::vector<T> x(d), y(d), zref(d);
@@ -159,7 +140,7 @@ void TestKernelAXYN() {
   using T = typename KernelTuple::data_type;
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     const T a = static_cast<T>(3);
@@ -202,7 +183,7 @@ void TestKernelXYN() {
   using T = typename KernelTuple::data_type;
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     std::vector<T> x(d), yref(d);
@@ -245,7 +226,7 @@ void TestKernelXRN() {
   auto last_acc = FLAGS_acc;
   FLAGS_acc = 1e-4;
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
     std::vector<T> x(d);
     RandomVec<T>(d, x.data());
@@ -279,7 +260,7 @@ void TestKernelLSTM() {
         const jit::lstm_attr_t attr(
             d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand),
            jit::to_kerneltype(act_cell), use_peephole);
-        auto ref = jit::GetRefer<KernelTuple>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
         std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
@@ -370,7 +351,7 @@ void TestKernelGRU() {
     for (auto& act_cand : all_acts) {
       const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
                                  jit::to_kerneltype(act_cand));
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
       RandomVec<T>(3 * d, xsrc.data());
@@ -423,7 +404,7 @@ void TestKernelNCHW16CMulNC() {
   using T = typename KernelTuple::data_type;
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   const int n = 3, c = 16 * 4, h = 10, w = 10;
-  auto ref = jit::GetRefer<KernelTuple>();
+  auto ref = jit::GetReferFunc<KernelTuple>();
   EXPECT_TRUE(ref != nullptr);
   int sz = n * c * h * w;
   std::vector<T> x(sz), y(n * c), zref(sz);
@@ -439,7 +420,9 @@ void TestKernelNCHW16CMulNC() {
   constexpr int simd_width = ZMM_FLOAT_BLOCK;
   int C = c / simd_width;
   auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(0);
-  auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(0);
+  auto funcs = jit::GetAllCandidateFuncs<KernelTuple, PlaceType>(0);
+  EXPECT_GT(funcs.size(), 0UL);
+  auto jitcode = funcs[0];
   EXPECT_TRUE(tgt != nullptr);
   if (std::is_same<T, float>::value &&
@@ -482,7 +465,7 @@ void TestKernelLayerNorm() {
     int left = n * x_dim_0;
     for (int x_dim_1 : TestSizes()) {
       int right = x_dim_1;
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       int sz = left * right;
       std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right),
@@ -555,7 +538,7 @@ void TestKernelCRFDecoding() {
   test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000));
   for (int seq_len : {1, 11, 17, 50}) {
     for (int tag_num : test_sizes) {
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
      EXPECT_TRUE(ref != nullptr);
       int x_sz = seq_len * tag_num;
       int w_sz = (tag_num + state_trans_base_idx) * tag_num;
@@ -606,7 +589,7 @@ void TestKernelSeqPool() {
     jit::seq_pool_attr_t attr(w, type);
     for (int h : test_sizes) {
       attr.h = h;
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(h * w), yref(w);
       RandomVec<T>(h * w, x.data());
@@ -649,7 +632,7 @@ void TestKernelEmbSeqPool() {
   for (auto type : pool_types) {
     for (int idx_w : {1, 2, 10, 16}) {
       for (int idx_h : {1, 2, 9, 13, 16}) {
-        auto ref = jit::GetRefer<KernelTuple>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<int64_t> idx(idx_h * idx_w);
         RandomVec<int64_t>(idx_h * idx_w, idx.data(), 0, tbl_h - 1);
@@ -701,7 +684,7 @@ void TestKernelMatMul() {
   for (int m : {1, 2, 3, 4}) {
     for (int n : {1, 2, 3, 4}) {
       for (int k : TestSizes()) {
-        auto ref = jit::GetRefer<KernelTuple>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> a(m * k), b(k * n), c(m * n);
         RandomVec<T>(m * k, a.data());
@@ -740,7 +723,7 @@ void TestKernelSoftmax() {
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int bs : {1, 2, 10}) {
     for (int n : TestSizes()) {
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(bs * n), y(bs * n);
       RandomVec<T>(bs * n, x.data());
@@ -808,7 +791,7 @@ void TestKernelSgd() {
     RandomVec<T>(rows_size * grad_w, grad.data());
     const int64_t* rows_data = rows.data();
     const T* grad_data = grad.data();
-    auto ref = jit::GetRefer<KernelTuple>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
     jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
     ref(&lr, param_data, grad_data, rows_data, out_data, &attr);
@@ -874,7 +857,7 @@ void TestKernelVBroadcast() {
     RandomVec<T>(w, x.data());
     const T* x_data = x.data();
     for (int64_t h : {1, 2, 6}) {
-      auto ref = jit::GetRefer<KernelTuple>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
      EXPECT_TRUE(ref != nullptr);
       std::vector<T> y(w * h);
       T* y_data = y.data();
@@ -900,6 +883,135 @@ void TestKernelVBroadcast() {
  }
}
// test pool
TEST(JITKernel_pool, jitcreator) {
const auto& jitcreators = jit::JitCodeCreatorPool::Instance().AllCreators();
EXPECT_EQ(jitcreators.size(), 25UL);
}
TEST(JITKernel_pool, jitpool) {
// jitpool is related with attr
const auto& kers = jit::JitCodePool<jit::kVAdd>().Instance().AllKernels();
EXPECT_EQ(kers.size(), 0UL);
jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3);
// after call GetAllCandidateKernels, it will create jitcode Automatically
EXPECT_EQ(kers.size(), 1UL);
}
TEST(JITKernel_pool, more) {
const auto& kers = jit::KernelPool::Instance().AllKernels();
EXPECT_EQ(kers.size(), 21UL);
}
TEST(JITKernel_pool, refer) {
const auto& kers = jit::ReferKernelPool::Instance().AllKernels();
EXPECT_EQ(kers.size(), 29UL);
}
// test helper
TEST(JITKernel_helper, GetAllCandidateKernels) {
auto fp_kers =
jit::GetAllCandidateKernels<jit::VExpTuple<float>, CPUPlace>(10);
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(fp_kers.size(), 1UL); // refer
#else
EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer
#endif
auto db_kers =
jit::GetAllCandidateKernels<jit::VExpTuple<double>, CPUPlace>(10);
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(db_kers.size(), 1UL); // refer
#else
EXPECT_GE(db_kers.size(), 2UL); // mkl, refer
#endif
}
TEST(JITKernel_helper, GetAllCandidateFuncsWithTypes) {
auto fp_kers =
jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<float>, CPUPlace>(10);
EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer
auto db_kers =
jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<double>, CPUPlace>(10);
EXPECT_GE(db_kers.size(), 2UL); // mkl, refer
}
TEST(JITKernel_helper, GetAllCandidateFuncs) {
auto funcs = jit::GetAllCandidateFuncs<jit::VExpTuple<float>, CPUPlace>(10);
auto kers = jit::GetAllCandidateKernels<jit::VExpTuple<float>, CPUPlace>(10);
EXPECT_EQ(funcs.size(), kers.size());
std::vector<float> x(10), tgt(10);
RandomVec<float>(10, x.data());
auto best = jit::GetDefaultBestFunc<jit::VExpTuple<float>, CPUPlace>(10);
best(x.data(), tgt.data(), 10);
for (auto f : funcs) {
std::vector<float> y(10);
f(x.data(), y.data(), 10);
ExpectEQ<float>(y.data(), tgt.data(), 10);
}
}
TEST(JITKernel_helper, attr) {
std::ostringstream out;
// KernelTypes
out << jit::to_string(jit::kNone) << jit::to_string(jit::kCRFDecoding)
<< jit::to_string(jit::kEmbSeqPool) << jit::to_string(jit::kGRUH1)
<< jit::to_string(jit::kGRUHtPart1) << jit::to_string(jit::kGRUHtPart2)
<< jit::to_string(jit::kHSum) << jit::to_string(jit::kHMax)
<< jit::to_string(jit::kLSTMCtHt) << jit::to_string(jit::kLSTMC1H1)
<< jit::to_string(jit::kLayerNorm) << jit::to_string(jit::kMatMul)
<< jit::to_string(jit::kNCHW16CMulNC) << jit::to_string(jit::kSeqPool)
<< jit::to_string(jit::kSoftmax) << jit::to_string(jit::kVAdd)
<< jit::to_string(jit::kVAddBias) << jit::to_string(jit::kVAddRelu)
<< jit::to_string(jit::kVBroadcast) << jit::to_string(jit::kVCopy)
<< jit::to_string(jit::kVExp) << jit::to_string(jit::kVIdentity)
<< jit::to_string(jit::kVMul) << jit::to_string(jit::kVRelu)
<< jit::to_string(jit::kVScal) << jit::to_string(jit::kSgd)
<< jit::to_string(jit::kVSigmoid) << jit::to_string(jit::kVSquare)
<< jit::to_string(jit::kVSub) << jit::to_string(jit::kVTanh);
EXPECT_EQ(out.str().size(), 234);
// SeqPoolTypes
out.str("");
out << jit::to_string(jit::kSum) << jit::to_string(jit::kAvg)
<< jit::to_string(jit::kSqrt);
EXPECT_EQ(out.str().size(), 13);
EXPECT_EQ(jit::to_kerneltype("relu"), jit::kVRelu);
EXPECT_EQ(jit::to_kerneltype("Identity"), jit::kVIdentity);
EXPECT_EQ(jit::to_kerneltype("VEXP"), jit::kVExp);
EXPECT_EQ(jit::to_kerneltype("SigmoiD"), jit::kVSigmoid);
EXPECT_EQ(jit::to_kerneltype("VTanh"), jit::kVTanh);
out.str("");
out << jit::lstm_attr_t(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
EXPECT_EQ(out.str().size(), 89);
out.str("");
out << jit::gru_attr_t(8, jit::kVIdentity, jit::kVSigmoid);
EXPECT_EQ(out.str().size(), 52);
out.str("");
out << jit::seq_pool_attr_t(8, jit::SeqPoolType::kSum);
EXPECT_EQ(out.str().size(), 44);
out.str("");
out << jit::emb_seq_pool_attr_t(1, 2, 3, 4, 5, jit::SeqPoolType::kAvg);
EXPECT_EQ(out.str().size(), 93);
out.str("");
out << jit::sgd_attr_t(1, 2, 3, 4, 5);
EXPECT_EQ(out.str().size(), 81);
out.str("");
out << jit::matmul_attr_t(1, 2, 3);
EXPECT_EQ(out.str().size(), 14);
}
// test kernerls
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
@@ -969,6 +1081,14 @@ TEST_CPU_KERNEL(Softmax);
 TEST_CPU_KERNEL(Sgd);
 TEST_CPU_KERNEL(VBroadcast);
 
+TEST(JITKernel, kernel_func) {
+  auto f1 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(3);
+  auto f2 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[3];
+  EXPECT_TRUE(f1 != nullptr);
+  EXPECT_TRUE(f1 == f2);
+  // TODO(TJ): check not equal
+}
+
 TEST(JITKernel_key, lstm) {
   jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
   jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
@@ -1000,11 +1120,3 @@ TEST(JITKernel_key, gru) {
   EXPECT_TRUE(key2 == key3);
   EXPECT_TRUE(key3 != key4);
 }
-
-TEST(JITKernel, kernel_func) {
-  auto f1 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(3);
-  auto f2 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[3];
-  EXPECT_TRUE(f1 != nullptr);
-  EXPECT_TRUE(f1 == f2);
-  // TODO(TJ): check not equal
-}