提交 900c789a 编写于 作者: T tensor-tang

use jitcode and use vmul

上级 53709e7e
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/blas.h" #include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -103,17 +104,24 @@ void VXXJitCode::genCode() { ...@@ -103,17 +104,24 @@ void VXXJitCode::genCode() {
ret(); ret();
} }
} // namespace gen class VMulCreator : public JitCodeCreator<int> {
public:
template <> bool UseMe(const int& attr) const override {
std::unique_ptr<GenBase> CreateJitCode<KernelType::vmul, float, int>(int attr) { return platform::MayIUse(platform::avx);
if (UseJitCode<KernelType::vmul, float, int>(attr)) {
return make_unique<gen::VMulJitCode>(
attr, CodeSize<KernelType::vmul, float, int>(attr));
} }
return nullptr; size_t CodeSize(const int& d) const override {
} return 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<VMulJitCode>(attr, CodeSize(attr));
}
};
} // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator);
...@@ -25,7 +25,18 @@ namespace gen { ...@@ -25,7 +25,18 @@ namespace gen {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode { class VXXJitCode : public JitCode {
public: public:
const char* name() const override { explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
this->genCode();
}
virtual const char* name() const {
std::string base = "VXXJitCode"; std::string base = "VXXJitCode";
if (scalar_index_ == 1) { if (scalar_index_ == 1) {
base += "_Scalar"; base += "_Scalar";
...@@ -45,15 +56,6 @@ class VXXJitCode : public JitCode { ...@@ -45,15 +56,6 @@ class VXXJitCode : public JitCode {
base += (with_relu_ ? "_Relu" : ""); base += (with_relu_ ? "_Relu" : "");
return base.c_str(); return base.c_str();
} }
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {}
// static bool init(int d, int scalar_index = 0);
void genCode() override; void genCode() override;
private: private:
......
...@@ -16,23 +16,6 @@ ...@@ -16,23 +16,6 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {} // namespace jit
template <>
size_t GetKey<int>(int d) {
return d;
}
// template <>
// std::shared_ptr<const GenBase> CreateJitCode<KernelType::vmul, int>(int attr)
// {
// if (UseJitCode<KernelType::vmul, int>(attr)) {
// return std::make_shared<gen::VMulJitCode<int>>(attr,
// CodeSize<KernelType::vmul, int>(attr)));
// }
// return nullptr;
// }
} // namespace jit
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -70,9 +70,10 @@ typedef enum { ...@@ -70,9 +70,10 @@ typedef enum {
class JitCode : public GenBase, public Xbyak::CodeGenerator { class JitCode : public GenBase, public Xbyak::CodeGenerator {
public: public:
explicit JitCode(size_t code_size, void* code_ptr = nullptr) explicit JitCode(size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator(code_size, code_ptr) { : Xbyak::CodeGenerator(code_size, code_ptr) {}
this->genCode();
} virtual const char* name() const = 0;
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); } size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override { const unsigned char* getCodeInternal() override {
...@@ -80,9 +81,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator { ...@@ -80,9 +81,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
return code; return code;
} }
virtual const char* name() const = 0;
virtual void genCode() = 0;
protected: protected:
Xbyak::Reg64 param1{abi_param1}; Xbyak::Reg64 param1{abi_param1};
const int EVEX_max_8b_offt = 0x200; const int EVEX_max_8b_offt = 0x200;
......
...@@ -23,6 +23,11 @@ namespace paddle { ...@@ -23,6 +23,11 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
template <>
size_t JitCodeKey<int>(int d) {
return d;
}
// refer do not need useme, it would be the last one. // refer do not need useme, it would be the last one.
void GenBase::dumpCode(const unsigned char* code) const { void GenBase::dumpCode(const unsigned char* code) const {
if (code) { if (code) {
......
...@@ -15,9 +15,8 @@ ...@@ -15,9 +15,8 @@
#pragma once #pragma once
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <memory> // for shared_ptr #include <memory> // for unique_ptr
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/macros.h"
DECLARE_bool(dump_jitcode); DECLARE_bool(dump_jitcode);
...@@ -25,29 +24,12 @@ namespace paddle { ...@@ -25,29 +24,12 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
// TODO(TJ): make these functions as virtual of a class
// Every JitCode should estimate the code size itself
template <KernelType KT, typename T, typename Attr>
size_t CodeSize(Attr attr) {
return 4096;
}
// Every JitCode should have a condition when to use this JitCode
template <KernelType KT, typename T, typename Attr>
bool UseJitCode(Attr attr) {
return false;
}
// Every JitCode should have a method to get the key from attribution
template <typename Attr>
size_t GetKey(Attr attr);
class GenBase : public Kernel { class GenBase : public Kernel {
public: public:
virtual ~GenBase() = default;
virtual const char* name() const = 0; virtual const char* name() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
virtual size_t getSize() const = 0; virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
template <typename FUNC> template <typename FUNC>
const FUNC getCode() { const FUNC getCode() {
const unsigned char* code = this->getCodeInternal(); const unsigned char* code = this->getCodeInternal();
...@@ -61,8 +43,31 @@ class GenBase : public Kernel { ...@@ -61,8 +43,31 @@ class GenBase : public Kernel {
void dumpCode(const unsigned char* code) const; void dumpCode(const unsigned char* code) const;
}; };
template <KernelType KT, typename T, typename Attr> // Every JitCode should have a method to get the key from attribution
std::unique_ptr<GenBase> CreateJitCode(Attr attr); template <typename Attr>
size_t JitCodeKey(Attr attr);
// Creator is used to creat the jitcode and save in pool.
// Every JitCode should have one creator.
class GenCreator {
public:
virtual ~GenCreator() = default;
};
template <typename Attr>
class JitCodeCreator : public GenCreator {
public:
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;
// create this code
virtual std::unique_ptr<GenBase> CreateJitCode(const Attr& attr) const = 0;
};
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
......
...@@ -21,6 +21,11 @@ namespace paddle { ...@@ -21,6 +21,11 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
static JitCodeCreatorPool g_creator_pool;
return g_creator_pool;
}
KernelPool& KernelPool::Instance() { KernelPool& KernelPool::Instance() {
static KernelPool g_kernel_pool; static KernelPool g_kernel_pool;
return g_kernel_pool; return g_kernel_pool;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include <memory> // for shared_ptr #include <memory> // for unique_ptr
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -52,6 +52,28 @@ class JitCodePool { ...@@ -52,6 +52,28 @@ class JitCodePool {
DISABLE_COPY_AND_ASSIGN(JitCodePool); DISABLE_COPY_AND_ASSIGN(JitCodePool);
}; };
class JitCodeCreatorPool {
typedef std::unique_ptr<const GenCreator> GenCreatorPtr;
typedef std::unordered_map<KernelKey, std::vector<GenCreatorPtr>,
KernelKey::Hash>
GenCreatorPtrMap;
public:
JitCodeCreatorPool() = default;
static JitCodeCreatorPool& Instance();
GenCreatorPtrMap& AllCreators() { return creators_; }
void Insert(const KernelKey& key, GenCreatorPtr value) {
if (creators_.find(key) == creators_.end()) {
creators_.emplace(key, std::vector<GenCreatorPtr>());
}
creators_.at(key).emplace_back(std::move(value));
}
private:
GenCreatorPtrMap creators_;
DISABLE_COPY_AND_ASSIGN(JitCodeCreatorPool);
};
typedef std::unique_ptr<const Kernel> KernelPtr; typedef std::unique_ptr<const Kernel> KernelPtr;
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash> typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
KernelMap; KernelMap;
...@@ -113,24 +135,33 @@ inline Func GetRefer() { ...@@ -113,24 +135,33 @@ inline Func GetRefer() {
template <KernelType KT, typename T, typename Func, typename Attr, template <KernelType KT, typename T, typename Func, typename Attr,
typename PlaceType = platform::CPUPlace> typename PlaceType = platform::CPUPlace>
const Func Get(Attr attr) { const Func Get(Attr attr) {
size_t key = GetKey<Attr>(attr); size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance(); auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) { if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>(); return codes.AllKernels().at(key)->template getCode<Func>();
} }
KernelKey kkey(KT, PlaceType());
if (std::is_same<PlaceType, platform::CPUPlace>::value) { if (std::is_same<PlaceType, platform::CPUPlace>::value) {
auto p = CreateJitCode<KT, T, Attr>(attr); // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
if (p) { auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto f = p->template getCode<Func>(); auto iter = creator_map.find(kkey);
codes.Insert(key, std::move(p)); auto& creators = iter->second;
return f; for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
} }
} }
// pool: (KernelKey(type, place), vector<Kernel>) // pool: (KernelKey(type, place), vector<KernelPtr>)
auto& pool = KernelPool().Instance().AllKernels(); auto& pool = KernelPool().Instance().AllKernels();
KernelKey kkey(KT, PlaceType());
auto iter = pool.find(kkey); auto iter = pool.find(kkey);
if (iter != pool.end()) { if (iter != pool.end()) {
auto& impls = iter->second; auto& impls = iter->second;
......
...@@ -116,7 +116,30 @@ class JitKernelRegistrar { ...@@ -116,7 +116,30 @@ class JitKernelRegistrar {
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \ #define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__) REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
// REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>); #define REGISTER_JITKERNEL_GEN(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::JitCodeCreatorPool, \
::paddle::platform::CPUPlace, __VA_ARGS__> \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
return 0; \
}
#define USE_JITKERNEL_GEN(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"USE_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \
static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
#define USE_JITKERNEL_REFER(kernel_type) \ #define USE_JITKERNEL_REFER(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
......
...@@ -61,6 +61,7 @@ void ExpectEQ(const T* target, const T* refer, int n) { ...@@ -61,6 +61,7 @@ void ExpectEQ(const T* target, const T* refer, int n) {
// TODO(TJ): remove me // TODO(TJ): remove me
USE_JITKERNEL_MORE(vmul, mkl); USE_JITKERNEL_MORE(vmul, mkl);
USE_JITKERNEL_REFER(vmul); USE_JITKERNEL_REFER(vmul);
USE_JITKERNEL_GEN(vmul);
TEST(JitKernel, vmul) { TEST(JitKernel, vmul) {
using T = float; using T = float;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册