Commit 900c789a authored by tensor-tang

use jitcode and use vmul

Parent 53709e7e
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
......@@ -103,17 +104,24 @@ void VXXJitCode::genCode() {
ret();
}
} // namespace gen
template <>
std::unique_ptr<GenBase> CreateJitCode<KernelType::vmul, float, int>(int attr) {
if (UseJitCode<KernelType::vmul, float, int>(attr)) {
return make_unique<gen::VMulJitCode>(
attr, CodeSize<KernelType::vmul, float, int>(attr));
class VMulCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
return platform::MayIUse(platform::avx);
}
return nullptr;
}
size_t CodeSize(const int& d) const override {
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<VMulJitCode>(attr, CodeSize(attr));
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator);
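For context, the REGISTER_JITKERNEL_GEN call above wires gen::VMulCreator into the creator pool at static-initialization time and also asserts that a refer kernel for vmul is registered. A rough hand-written equivalent, sketched against the JitCodeCreatorPool interface added later in this commit (not the actual macro expansion, and it assumes gen::VMulCreator and the jit headers are visible at this point):

namespace jit = paddle::operators::jit;

// Sketch only: register a creator for (vmul, CPUPlace) by hand.
static int RegisterVMulGenCreator() {
  jit::JitCodeCreatorPool::Instance().Insert(
      jit::KernelKey(jit::KernelType::vmul, paddle::platform::CPUPlace()),
      std::unique_ptr<const jit::GenCreator>(new jit::gen::VMulCreator()));
  return 0;
}
static int vmul_gen_creator_registered = RegisterVMulGenCreator();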
......@@ -25,7 +25,18 @@ namespace gen {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
const char* name() const override {
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
this->genCode();
}
virtual const char* name() const {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
......@@ -45,15 +56,6 @@ class VXXJitCode : public JitCode {
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {}
// static bool init(int d, int scalar_index = 0);
void genCode() override;
private:
......
......@@ -16,23 +16,6 @@
namespace paddle {
namespace operators {
namespace jit {
template <>
size_t GetKey<int>(int d) {
return d;
}
// template <>
// std::shared_ptr<const GenBase> CreateJitCode<KernelType::vmul, int>(int attr)
// {
// if (UseJitCode<KernelType::vmul, int>(attr)) {
// return std::make_shared<gen::VMulJitCode<int>>(attr,
// CodeSize<KernelType::vmul, int>(attr)));
// }
// return nullptr;
// }
} // namespace jit
namespace jit {} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -70,9 +70,10 @@ typedef enum {
class JitCode : public GenBase, public Xbyak::CodeGenerator {
public:
explicit JitCode(size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator(code_size, code_ptr) {
this->genCode();
}
: Xbyak::CodeGenerator(code_size, code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override {
......@@ -80,9 +81,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
return code;
}
virtual const char* name() const = 0;
virtual void genCode() = 0;
protected:
Xbyak::Reg64 param1{abi_param1};
const int EVEX_max_8b_offt = 0x200;
......
......@@ -23,6 +23,11 @@ namespace paddle {
namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(int d) {
return d;
}
// The refer kernel does not need UseMe; it is always the last fallback.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
......
......@@ -15,9 +15,8 @@
#pragma once
#include <gflags/gflags.h>
#include <memory> // for shared_ptr
#include <memory> // for unique_ptr
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/macros.h"
DECLARE_bool(dump_jitcode);
......@@ -25,29 +24,12 @@ namespace paddle {
namespace operators {
namespace jit {
// TODO(TJ): make these functions as virtual of a class
// Every JitCode should estimate the code size itself
template <KernelType KT, typename T, typename Attr>
size_t CodeSize(Attr attr) {
return 4096;
}
// Every JitCode should have a condition when to use this JitCode
template <KernelType KT, typename T, typename Attr>
bool UseJitCode(Attr attr) {
return false;
}
// Every JitCode should have a method to get the key from attribution
template <typename Attr>
size_t GetKey(Attr attr);
class GenBase : public Kernel {
public:
virtual ~GenBase() = default;
virtual const char* name() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
template <typename FUNC>
const FUNC getCode() {
const unsigned char* code = this->getCodeInternal();
......@@ -61,8 +43,31 @@ class GenBase : public Kernel {
void dumpCode(const unsigned char* code) const;
};
template <KernelType KT, typename T, typename Attr>
std::unique_ptr<GenBase> CreateJitCode(Attr attr);
// Every JitCode should have a method to get the key from its attributes
template <typename Attr>
size_t JitCodeKey(Attr attr);
// A Creator is used to create the jitcode and save it in the pool.
// Every JitCode should have one creator.
class GenCreator {
public:
virtual ~GenCreator() = default;
};
template <typename Attr>
class JitCodeCreator : public GenCreator {
public:
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;
// create this code
virtual std::unique_ptr<GenBase> CreateJitCode(const Attr& attr) const = 0;
};
} // namespace jit
} // namespace operators
......
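For orientation, a minimal sketch of driving a JitCodeCreator by hand, outside the pools. In real use the creator lives in JitCodeCreatorPool and is consulted by jit::Get in helper.h; the vmul function signature below is an assumption rather than something shown in this diff:

// Sketch only: the generated code object must stay alive while its
// function pointer is being called.
using VMulFunc = void (*)(const float*, const float*, float*, int);

void DemoCreatorUse(const JitCodeCreator<int>& creator, int d,
                    const float* x, const float* y, float* z) {
  if (!creator.UseMe(d)) return;                             // eligibility check
  std::unique_ptr<GenBase> code = creator.CreateJitCode(d);  // emit machine code
  if (!code) return;
  auto vmul = code->getCode<VMulFunc>();                     // cast buffer to callable
  vmul(x, y, z, d);                                          // code owns the buffer
}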
......@@ -21,6 +21,11 @@ namespace paddle {
namespace operators {
namespace jit {
JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
static JitCodeCreatorPool g_creator_pool;
return g_creator_pool;
}
KernelPool& KernelPool::Instance() {
static KernelPool g_kernel_pool;
return g_kernel_pool;
......
......@@ -14,7 +14,7 @@
#pragma once
#include <memory> // for shared_ptr
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <vector>
......@@ -52,6 +52,28 @@ class JitCodePool {
DISABLE_COPY_AND_ASSIGN(JitCodePool);
};
class JitCodeCreatorPool {
typedef std::unique_ptr<const GenCreator> GenCreatorPtr;
typedef std::unordered_map<KernelKey, std::vector<GenCreatorPtr>,
KernelKey::Hash>
GenCreatorPtrMap;
public:
JitCodeCreatorPool() = default;
static JitCodeCreatorPool& Instance();
GenCreatorPtrMap& AllCreators() { return creators_; }
void Insert(const KernelKey& key, GenCreatorPtr value) {
if (creators_.find(key) == creators_.end()) {
creators_.emplace(key, std::vector<GenCreatorPtr>());
}
creators_.at(key).emplace_back(std::move(value));
}
private:
GenCreatorPtrMap creators_;
DISABLE_COPY_AND_ASSIGN(JitCodeCreatorPool);
};
typedef std::unique_ptr<const Kernel> KernelPtr;
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
KernelMap;
......@@ -113,24 +135,33 @@ inline Func GetRefer() {
template <KernelType KT, typename T, typename Func, typename Attr,
typename PlaceType = platform::CPUPlace>
const Func Get(Attr attr) {
size_t key = GetKey<Attr>(attr);
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
KernelKey kkey(KT, PlaceType());
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
auto p = CreateJitCode<KT, T, Attr>(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
}
}
// pool: (KernelKey(type, place), vector<Kernel>)
// pool: (KernelKey(type, place), vector<KernelPtr>)
auto& pool = KernelPool().Instance().AllKernels();
KernelKey kkey(KT, PlaceType());
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
......
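As a usage sketch of the lookup above: a caller asks Get<> for a kernel, and the cached jitcode pool, the registered creators, the "more" kernels, and finally the refer kernel are tried in that order. The function-pointer type for vmul below is assumed from the refer kernel's convention, not taken from this diff:

#include "paddle/fluid/operators/jit/helper.h"

namespace jit = paddle::operators::jit;
// Assumed vmul signature: z[i] = x[i] * y[i] for i in [0, n).
using VMulFunc = void (*)(const float*, const float*, float*, int);

void RunVMul(const float* x, const float* y, float* z, int n) {
  auto vmul = jit::Get<jit::KernelType::vmul, float, VMulFunc, int,
                       paddle::platform::CPUPlace>(n);
  vmul(x, y, z, n);
}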
......@@ -116,7 +116,30 @@ class JitKernelRegistrar {
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
// REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);
#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::JitCodeCreatorPool, \
::paddle::platform::CPUPlace, __VA_ARGS__> \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
return 0; \
}
#define USE_JITKERNEL_GEN(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"USE_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \
static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
#define USE_JITKERNEL_REFER(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
......
......@@ -61,6 +61,7 @@ void ExpectEQ(const T* target, const T* refer, int n) {
// TODO(TJ): remove me
USE_JITKERNEL_MORE(vmul, mkl);
USE_JITKERNEL_REFER(vmul);
USE_JITKERNEL_GEN(vmul);
TEST(JitKernel, vmul) {
using T = float;
......