提交 45bdd84d 编写于 作者: T tensor-tang

enhance the jitkernel helper and add unit tests

test=develop
上级 14a764c9
......@@ -111,33 +111,11 @@ template <typename KernelTuple, typename PlaceType, typename... Args>
void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
BenchFunc<KernelTuple, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos;
// test refer
auto refer = jit::GetRefer<KernelTuple>();
if (!refer) {
LOG(FATAL) << "Refer can not be empty!";
auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
for (auto f : funcs) {
infos.push_back(std::make_pair(f.first, benchmark(f.second, args...)));
}
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
std::make_pair(i->ImplType(), benchmark(more, args...)));
}
}
}
// Test result from Get function
auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
if (!tgt) {
......
......@@ -81,7 +81,7 @@ void VActJitCode::genCode() {
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override; \
bool CanBeUsed(const int& attr) const override; \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
......@@ -96,27 +96,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me
bool VReluCreator::UseMe(const int& d) const {
bool VReluCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VSquareCreator::UseMe(const int& d) const {
bool VSquareCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VIdentityCreator::UseMe(const int& d) const {
bool VIdentityCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VExpCreator::UseMe(const int& d) const {
bool VExpCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d < 32;
}
bool VSigmoidCreator::UseMe(const int& d) const {
bool VSigmoidCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VTanhCreator::UseMe(const int& d) const {
bool VTanhCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
......
......@@ -142,7 +142,7 @@ void NCHW16CMulNCJitCode::genCode() {
class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
bool CanBeUsed(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& d) const override { return 256 * 1024; }
......@@ -154,7 +154,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx) && attr <= 1024; \
} \
size_t CodeSize(const int& d) const override { \
......
......@@ -121,7 +121,7 @@ void EmbSeqPoolJitCode::genCode() {
class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
public:
bool UseMe(const emb_seq_pool_attr_t& attr) const override {
bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx) &&
attr.table_width % YMM_FLOAT_BLOCK == 0;
}
......
......@@ -86,7 +86,7 @@ void GRUJitCode::genCode() {
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \
bool CanBeUsed(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
......
......@@ -76,7 +76,7 @@ void HOPVJitCode::genCode() {
#define DECLARE_HOP_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
......
......@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override {
const unsigned char* getCodeInternal() const override {
const Xbyak::uint8* code = CodeGenerator::getCode();
return code;
}
......
......@@ -114,7 +114,7 @@ void LSTMJitCode::genCode() {
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \
bool CanBeUsed(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
......
......@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
public:
bool UseMe(const matmul_attr_t& attr) const override {
bool CanBeUsed(const matmul_attr_t& attr) const override {
return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
}
......
......@@ -57,7 +57,7 @@ void SeqPoolJitCode::genCode() {
class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
public:
bool UseMe(const seq_pool_attr_t& attr) const override {
bool CanBeUsed(const seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx);
}
size_t CodeSize(const seq_pool_attr_t& attr) const override {
......
......@@ -104,7 +104,7 @@ void SgdJitCode::genCode() {
class SgdCreator : public JitCodeCreator<sgd_attr_t> {
public:
bool UseMe(const sgd_attr_t& attr) const override {
bool CanBeUsed(const sgd_attr_t& attr) const override {
return platform::MayIUse(platform::avx) &&
attr.grad_width % YMM_FLOAT_BLOCK == 0;
}
......
......@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
class VBroadcastCreator : public JitCodeCreator<int64_t> {
public:
bool UseMe(const int64_t& w) const override {
bool CanBeUsed(const int64_t& w) const override {
return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
}
size_t CodeSize(const int64_t& w) const override {
......
......@@ -31,7 +31,7 @@ namespace paddle {
namespace operators {
namespace jit {
// refer do not need useme, it would be the last one.
// refer do not need CanBeUsed, it would be the last one.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
static int counter = 0;
......
......@@ -31,9 +31,10 @@ class GenBase : public Kernel {
virtual ~GenBase() = default;
virtual std::string name() const = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
virtual const unsigned char* getCodeInternal() const = 0;
const char* ImplType() const override { return "JitCode"; }
template <typename Func>
Func getCode() {
Func getCode() const {
const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
......@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
virtual bool CanBeUsed(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;
......
......@@ -14,9 +14,6 @@
#pragma once
extern "C" {
#include <xxhash.h>
}
#include <iostream>
#include <string>
#include <unordered_map>
......@@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if<
std::is_same<typename KernelTuple::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuple::func_type>::type
const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuple::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KernelTuple::kernel_type>().Instance();
auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
return codes.AllKernels().at(key).get();
}
// creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
if (i && i->CanBeUsed(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
auto res = p.get();
codes.Insert(key, std::move(p));
return f;
return res;
}
}
}
......@@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuple::func_type>::type
const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr;
}
......@@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <typename KernelTuple>
inline typename KernelTuple::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
......@@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
if (i) {
return i->GetFunc();
return i;
}
}
return nullptr;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type Get(
template <typename KernelTuple>
inline typename KernelTuple::func_type GetReferFunc() {
auto ker = GetReferKernel<KernelTuple>();
auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
PADDLE_ENFORCE(p, "The Refer kernel should exsit");
return p->GetFunc();
}
// Return all Kernels that can be used
template <typename KernelTuple, typename PlaceType>
std::vector<const Kernel*> GetAllCandidateKernels(
const typename KernelTuple::attr_type& attr) {
auto jitfunc = GetJitCode<KernelTuple, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
// the search order shoudl be jitcode > more > refer
std::vector<const Kernel*> res;
auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
if (jitker) {
res.emplace_back(jitker);
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
// more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
auto& pool = KernelPool::Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
if (i && i->CanBeUsed(attr)) {
res.emplace_back(i);
}
}
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KernelTuple>();
auto ref = GetReferKernel<KernelTuple>();
PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
res.emplace_back(ref);
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
std::vector<std::pair<std::string, typename KernelTuple::func_type>>
GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuple::func_type;
auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
std::vector<std::pair<std::string, Func>> res;
for (auto k : kers) {
std::string name = k->ImplType();
if (name == "JitCode") {
auto i = dynamic_cast<const GenBase*>(k);
PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
} else {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
PADDLE_ENFORCE(i, "kernel cast can not fail.");
res.emplace_back(std::make_pair(name, i->GetFunc()));
}
}
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
std::vector<typename KernelTuple::func_type> res;
for (auto& i : funcs) {
res.emplace_back(i.second);
}
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type GetDefaultBestFunc(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
PADDLE_ENFORCE_GE(funcs.size(), 1UL);
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
return funcs[0];
}
template <typename KernelTuple, typename PlaceType>
......@@ -134,17 +187,13 @@ class KernelFuncs {
// the exposed interface to use
typename KernelTuple::func_type At(
const typename KernelTuple::attr_type& attr) {
// XXH64: 13.8 GB/s
// TODO(TJ): change me, maybe not all attr change need one key, should be
// attrkey
int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0);
// Maybe here is not good enough, not all kernels should have jitcode
int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
if (Has(key)) {
return funcs_.at(key);
}
// If do not have this attr in cache,
// then could run some runtime benchmark of this attr and save the best one.
// Here just get the offline benchmarked best one.
auto func = Get<KernelTuple, PlaceType>(attr);
// If do not have this attr in cache then get the default best
auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
Insert(key, func);
return func;
}
......@@ -156,7 +205,6 @@ class KernelFuncs {
protected:
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int64_t key, typename KernelTuple::func_type func) {
funcs_.emplace(key, func);
}
......
......@@ -302,6 +302,7 @@ class Kernel {
public:
Kernel() = default;
virtual ~Kernel() = default;
virtual const char* ImplType() const = 0;
DISABLE_COPY_AND_ASSIGN(Kernel);
};
......@@ -312,8 +313,8 @@ class KernelMore : public Kernel {
using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuple::attr_type;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0;
// specify this kernel can be used, means it should not fail if use it.
virtual bool CanBeUsed(const Attr& attr) const = 0;
protected:
Func func{nullptr};
......@@ -323,7 +324,7 @@ template <typename KernelTuple>
class ReferKernel : public KernelMore<KernelTuple> {
public:
// Refer code can always be used
bool UseMe(const typename KernelTuple::attr_type& attr) const override {
bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
return true;
}
const char* ImplType() const override { return "Refer"; }
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -49,6 +50,8 @@ static inline int act_type_convert(KernelType type) {
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
// XXH64: 13.8 GB/s
size_t key = attr.d;
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
......
......@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
}
}
bool CRFDecodingKernel::UseMe(const int& d) const {
bool CRFDecodingKernel::CanBeUsed(const int& d) const {
#ifdef __AVX512F__
constexpr int block = ZMM_FLOAT_BLOCK;
#else
......
......@@ -29,7 +29,8 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
public:
CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe(const typename CRFDecodingTuple<float>::attr_type&) const override;
bool CanBeUsed(
const typename CRFDecodingTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
......
......@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
}
}
bool LayerNormKernel::UseMe(const int& d) const {
bool LayerNormKernel::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
}
......
......@@ -30,7 +30,8 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
public:
LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuple<float>::attr_type&) const override;
bool CanBeUsed(
const typename LayerNormTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
......
......@@ -204,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
}
// TODO(TJ): tuning me
bool VSigmoidKernel::UseMe(const int& d) const { return true; }
bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; }
bool VTanhKernel::UseMe(const int& d) const { return true; }
bool VTanhKernel::CanBeUsed(const int& d) const { return true; }
bool SoftmaxKernel::UseMe(const int& d) const { return true; }
bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; }
bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUH1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
} // namespace mix
} // namespace more
......
......@@ -38,7 +38,7 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name; } \
bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
}
......
......@@ -130,104 +130,105 @@ void ASum<double>(const double* x, double* res, int n) {
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
bool VMulKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VAddKernel<float>::UseMe(const int& d) const {
bool VAddKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d > 512;
}
template <>
bool VScalKernel<float>::UseMe(const int& d) const {
bool VScalKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VExpKernel<float>::UseMe(const int& d) const {
bool VExpKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VSquareKernel<float>::UseMe(const int& d) const {
bool VSquareKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VCopyKernel<float>::UseMe(const int& d) const {
bool VCopyKernel<float>::CanBeUsed(const int& d) const {
return d > 15;
}
template <>
bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
bool VBroadcastKernel<float>::CanBeUsed(const int64_t& d) const {
return d > 127;
}
template <>
bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
bool VBroadcastKernel<double>::CanBeUsed(const int64_t& attr) const {
return true;
}
template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
bool VSigmoidKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VTanhKernel<float>::UseMe(const int& d) const {
bool VTanhKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
bool SeqPoolKernel<float>::CanBeUsed(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
bool SeqPoolKernel<double>::CanBeUsed(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
bool EmbSeqPoolKernel<float>::CanBeUsed(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
bool EmbSeqPoolKernel<double>::CanBeUsed(
const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const {
bool SgdKernel<float>::CanBeUsed(const sgd_attr_t& attr) const {
return true;
}
template <>
bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const {
bool SgdKernel<double>::CanBeUsed(const sgd_attr_t& attr) const {
return true;
}
template <>
bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
bool MatMulKernel<float>::CanBeUsed(const matmul_attr_t& attr) const {
return platform::MayIUse(platform::avx);
}
template <>
bool MatMulKernel<double>::UseMe(const matmul_attr_t& attr) const {
bool MatMulKernel<double>::CanBeUsed(const matmul_attr_t& attr) const {
return true;
}
template <>
bool SoftmaxKernel<float>::UseMe(const int& d) const {
bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
// tuned on avx2
return platform::MayIUse(platform::avx) && d < 60;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \
bool func##Kernel<double>::CanBeUsed(const int& d) const { \
return true; \
}
......
......@@ -180,7 +180,7 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
}
......
......@@ -49,7 +49,7 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
void operator()(KernelType kt) const {
KernelKey kkey(kt, PlaceType());
Pool().Instance().Insert(kkey,
Pool::Instance().Insert(kkey,
std::move(make_unique<const KERNEL_IMPL_TYPE>()));
constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <iostream>
#include <random>
#include <string>
#include <vector>
......@@ -68,31 +69,11 @@ template <typename KernelTuple, typename PlaceType, typename Tester,
typename... Args>
void TestAllImpls(const typename KernelTuple::attr_type& attr,
const Tester& verifier, const Args&... args) {
// test jitcode
auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel ";
verifier(jitcode, args...);
}
// test all impls in more
jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel : " << i->ImplType();
verifier(more, args...);
}
}
}
// test result from Get function
VLOG(10) << "Test final get function ";
auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
verifier(tgt, args...);
auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
for (auto f : funcs) {
VLOG(10) << "Test Kernel " << f.first;
verifier(f.second, args...);
}
}
template <typename KernelTuple, typename PlaceType>
......@@ -100,7 +81,7 @@ void TestKernelXYZN() {
using T = typename KernelTuple::data_type;
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d), y(d), zref(d);
......@@ -159,7 +140,7 @@ void TestKernelAXYN() {
using T = typename KernelTuple::data_type;
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
const T a = static_cast<T>(3);
......@@ -202,7 +183,7 @@ void TestKernelXYN() {
using T = typename KernelTuple::data_type;
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d), yref(d);
......@@ -245,7 +226,7 @@ void TestKernelXRN() {
auto last_acc = FLAGS_acc;
FLAGS_acc = 1e-4;
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d);
RandomVec<T>(d, x.data());
......@@ -279,7 +260,7 @@ void TestKernelLSTM() {
const jit::lstm_attr_t attr(
d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand),
jit::to_kerneltype(act_cell), use_peephole);
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
......@@ -370,7 +351,7 @@ void TestKernelGRU() {
for (auto& act_cand : all_acts) {
const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
jit::to_kerneltype(act_cand));
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
RandomVec<T>(3 * d, xsrc.data());
......@@ -423,7 +404,7 @@ void TestKernelNCHW16CMulNC() {
using T = typename KernelTuple::data_type;
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
const int n = 3, c = 16 * 4, h = 10, w = 10;
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
int sz = n * c * h * w;
std::vector<T> x(sz), y(n * c), zref(sz);
......@@ -439,7 +420,9 @@ void TestKernelNCHW16CMulNC() {
constexpr int simd_width = ZMM_FLOAT_BLOCK;
int C = c / simd_width;
auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(0);
auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(0);
auto funcs = jit::GetAllCandidateFuncs<KernelTuple, PlaceType>(0);
EXPECT_GT(funcs.size(), 0UL);
auto jitcode = funcs[0];
EXPECT_TRUE(tgt != nullptr);
if (std::is_same<T, float>::value &&
......@@ -482,7 +465,7 @@ void TestKernelLayerNorm() {
int left = n * x_dim_0;
for (int x_dim_1 : TestSizes()) {
int right = x_dim_1;
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
int sz = left * right;
std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right),
......@@ -555,7 +538,7 @@ void TestKernelCRFDecoding() {
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000));
for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : test_sizes) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
int x_sz = seq_len * tag_num;
int w_sz = (tag_num + state_trans_base_idx) * tag_num;
......@@ -606,7 +589,7 @@ void TestKernelSeqPool() {
jit::seq_pool_attr_t attr(w, type);
for (int h : test_sizes) {
attr.h = h;
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(h * w), yref(w);
RandomVec<T>(h * w, x.data());
......@@ -649,7 +632,7 @@ void TestKernelEmbSeqPool() {
for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 9, 13, 16}) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<int64_t> idx(idx_h * idx_w);
RandomVec<int64_t>(idx_h * idx_w, idx.data(), 0, tbl_h - 1);
......@@ -701,7 +684,7 @@ void TestKernelMatMul() {
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
for (int k : TestSizes()) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data());
......@@ -740,7 +723,7 @@ void TestKernelSoftmax() {
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n);
RandomVec<T>(bs * n, x.data());
......@@ -808,7 +791,7 @@ void TestKernelSgd() {
RandomVec<T>(rows_size * grad_w, grad.data());
const int64_t* rows_data = rows.data();
const T* grad_data = grad.data();
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
ref(&lr, param_data, grad_data, rows_data, out_data, &attr);
......@@ -874,7 +857,7 @@ void TestKernelVBroadcast() {
RandomVec<T>(w, x.data());
const T* x_data = x.data();
for (int64_t h : {1, 2, 6}) {
auto ref = jit::GetRefer<KernelTuple>();
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> y(w * h);
T* y_data = y.data();
......@@ -900,6 +883,135 @@ void TestKernelVBroadcast() {
}
}
// test pool
TEST(JITKernel_pool, jitcreator) {
const auto& jitcreators = jit::JitCodeCreatorPool::Instance().AllCreators();
EXPECT_EQ(jitcreators.size(), 25UL);
}
TEST(JITKernel_pool, jitpool) {
// jitpool is related with attr
const auto& kers = jit::JitCodePool<jit::kVAdd>().Instance().AllKernels();
EXPECT_EQ(kers.size(), 0UL);
jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3);
// after call GetAllCandidateKernels, it will create jitcode Automatically
EXPECT_EQ(kers.size(), 1UL);
}
TEST(JITKernel_pool, more) {
const auto& kers = jit::KernelPool::Instance().AllKernels();
EXPECT_EQ(kers.size(), 21UL);
}
TEST(JITKernel_pool, refer) {
const auto& kers = jit::ReferKernelPool::Instance().AllKernels();
EXPECT_EQ(kers.size(), 29UL);
}
// test helper
TEST(JITKernel_helper, GetAllCandidateKernels) {
auto fp_kers =
jit::GetAllCandidateKernels<jit::VExpTuple<float>, CPUPlace>(10);
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(fp_kers.size(), 1UL); // refer
#else
EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer
#endif
auto db_kers =
jit::GetAllCandidateKernels<jit::VExpTuple<double>, CPUPlace>(10);
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(db_kers.size(), 1UL); // refer
#else
EXPECT_GE(db_kers.size(), 2UL); // mkl, refer
#endif
}
TEST(JITKernel_helper, GetAllCandidateFuncsWithTypes) {
auto fp_kers =
jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<float>, CPUPlace>(10);
EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer
auto db_kers =
jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<double>, CPUPlace>(10);
EXPECT_GE(db_kers.size(), 2UL); // mkl, refer
}
TEST(JITKernel_helper, GetAllCandidateFuncs) {
auto funcs = jit::GetAllCandidateFuncs<jit::VExpTuple<float>, CPUPlace>(10);
auto kers = jit::GetAllCandidateKernels<jit::VExpTuple<float>, CPUPlace>(10);
EXPECT_EQ(funcs.size(), kers.size());
std::vector<float> x(10), tgt(10);
RandomVec<float>(10, x.data());
auto best = jit::GetDefaultBestFunc<jit::VExpTuple<float>, CPUPlace>(10);
best(x.data(), tgt.data(), 10);
for (auto f : funcs) {
std::vector<float> y(10);
f(x.data(), y.data(), 10);
ExpectEQ<float>(y.data(), tgt.data(), 10);
}
}
TEST(JITKernel_helper, attr) {
std::ostringstream out;
// KernelTypes
out << jit::to_string(jit::kNone) << jit::to_string(jit::kCRFDecoding)
<< jit::to_string(jit::kEmbSeqPool) << jit::to_string(jit::kGRUH1)
<< jit::to_string(jit::kGRUHtPart1) << jit::to_string(jit::kGRUHtPart2)
<< jit::to_string(jit::kHSum) << jit::to_string(jit::kHMax)
<< jit::to_string(jit::kLSTMCtHt) << jit::to_string(jit::kLSTMC1H1)
<< jit::to_string(jit::kLayerNorm) << jit::to_string(jit::kMatMul)
<< jit::to_string(jit::kNCHW16CMulNC) << jit::to_string(jit::kSeqPool)
<< jit::to_string(jit::kSoftmax) << jit::to_string(jit::kVAdd)
<< jit::to_string(jit::kVAddBias) << jit::to_string(jit::kVAddRelu)
<< jit::to_string(jit::kVBroadcast) << jit::to_string(jit::kVCopy)
<< jit::to_string(jit::kVExp) << jit::to_string(jit::kVIdentity)
<< jit::to_string(jit::kVMul) << jit::to_string(jit::kVRelu)
<< jit::to_string(jit::kVScal) << jit::to_string(jit::kSgd)
<< jit::to_string(jit::kVSigmoid) << jit::to_string(jit::kVSquare)
<< jit::to_string(jit::kVSub) << jit::to_string(jit::kVTanh);
EXPECT_EQ(out.str().size(), 234);
// SeqPoolTypes
out.str("");
out << jit::to_string(jit::kSum) << jit::to_string(jit::kAvg)
<< jit::to_string(jit::kSqrt);
EXPECT_EQ(out.str().size(), 13);
EXPECT_EQ(jit::to_kerneltype("relu"), jit::kVRelu);
EXPECT_EQ(jit::to_kerneltype("Identity"), jit::kVIdentity);
EXPECT_EQ(jit::to_kerneltype("VEXP"), jit::kVExp);
EXPECT_EQ(jit::to_kerneltype("SigmoiD"), jit::kVSigmoid);
EXPECT_EQ(jit::to_kerneltype("VTanh"), jit::kVTanh);
out.str("");
out << jit::lstm_attr_t(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
EXPECT_EQ(out.str().size(), 89);
out.str("");
out << jit::gru_attr_t(8, jit::kVIdentity, jit::kVSigmoid);
EXPECT_EQ(out.str().size(), 52);
out.str("");
out << jit::seq_pool_attr_t(8, jit::SeqPoolType::kSum);
EXPECT_EQ(out.str().size(), 44);
out.str("");
out << jit::emb_seq_pool_attr_t(1, 2, 3, 4, 5, jit::SeqPoolType::kAvg);
EXPECT_EQ(out.str().size(), 93);
out.str("");
out << jit::sgd_attr_t(1, 2, 3, 4, 5);
EXPECT_EQ(out.str().size(), 81);
out.str("");
out << jit::matmul_attr_t(1, 2, 3);
EXPECT_EQ(out.str().size(), 14);
}
// test kernerls
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
......@@ -969,6 +1081,14 @@ TEST_CPU_KERNEL(Softmax);
TEST_CPU_KERNEL(Sgd);
TEST_CPU_KERNEL(VBroadcast);
TEST(JITKernel, kernel_func) {
auto f1 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(3);
auto f2 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[3];
EXPECT_TRUE(f1 != nullptr);
EXPECT_TRUE(f1 == f2);
// TODO(TJ): check not equal
}
TEST(JITKernel_key, lstm) {
jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
......@@ -1000,11 +1120,3 @@ TEST(JITKernel_key, gru) {
EXPECT_TRUE(key2 == key3);
EXPECT_TRUE(key3 != key4);
}
TEST(JITKernel, kernel_func) {
auto f1 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(3);
auto f2 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[3];
EXPECT_TRUE(f1 != nullptr);
EXPECT_TRUE(f1 == f2);
// TODO(TJ): check not equal
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册