From 900c789a35798cf73b67a2bb7b7944f3110c7bda Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 10 Dec 2018 12:00:46 +0000 Subject: [PATCH] use jitcode and use vmul --- paddle/fluid/operators/jit/gen/blas.cc | 26 ++++++++---- paddle/fluid/operators/jit/gen/blas.h | 22 +++++----- paddle/fluid/operators/jit/gen/jitcode.cc | 19 +-------- paddle/fluid/operators/jit/gen/jitcode.h | 10 ++--- paddle/fluid/operators/jit/gen_base.cc | 5 +++ paddle/fluid/operators/jit/gen_base.h | 51 +++++++++++++---------- paddle/fluid/operators/jit/kernel_pool.cc | 5 +++ paddle/fluid/operators/jit/kernel_pool.h | 49 ++++++++++++++++++---- paddle/fluid/operators/jit/registry.h | 25 ++++++++++- paddle/fluid/operators/jit/test.cc | 1 + 10 files changed, 137 insertions(+), 76 deletions(-) diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc index 4a8b4554c8b..3e5ce540647 100644 --- a/paddle/fluid/operators/jit/gen/blas.cc +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/operators/jit/gen/blas.h" #include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -103,17 +104,24 @@ void VXXJitCode::genCode() { ret(); } -} // namespace gen - -template <> -std::unique_ptr CreateJitCode(int attr) { - if (UseJitCode(attr)) { - return make_unique( - attr, CodeSize(attr)); +class VMulCreator : public JitCodeCreator { + public: + bool UseMe(const int& attr) const override { + return platform::MayIUse(platform::avx); } - return nullptr; -} + size_t CodeSize(const int& d) const override { + return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; + } + std::unique_ptr CreateJitCode(const int& attr) const override { + return make_unique(attr, CodeSize(attr)); + } +}; +} // namespace gen } // namespace jit } // namespace operators } // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator); diff --git a/paddle/fluid/operators/jit/gen/blas.h b/paddle/fluid/operators/jit/gen/blas.h index edc05f86a03..60f32805678 100644 --- a/paddle/fluid/operators/jit/gen/blas.h +++ b/paddle/fluid/operators/jit/gen/blas.h @@ -25,7 +25,18 @@ namespace gen { // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) class VXXJitCode : public JitCode { public: - const char* name() const override { + explicit VXXJitCode(int d, operand_type type, int scalar_index, + bool with_relu, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), + num_(d), + type_(type), + scalar_index_(scalar_index), + with_relu_(with_relu) { + this->genCode(); + } + + virtual const char* name() const { std::string base = "VXXJitCode"; if (scalar_index_ == 1) { base += "_Scalar"; @@ -45,15 +56,6 @@ class VXXJitCode : public JitCode { base += (with_relu_ ? "_Relu" : ""); return base.c_str(); } - explicit VXXJitCode(int d, operand_type type, int scalar_index, - bool with_relu, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), - num_(d), - type_(type), - scalar_index_(scalar_index), - with_relu_(with_relu) {} - // static bool init(int d, int scalar_index = 0); void genCode() override; private: diff --git a/paddle/fluid/operators/jit/gen/jitcode.cc b/paddle/fluid/operators/jit/gen/jitcode.cc index 93204d340e9..7aaf6a2ff65 100644 --- a/paddle/fluid/operators/jit/gen/jitcode.cc +++ b/paddle/fluid/operators/jit/gen/jitcode.cc @@ -16,23 +16,6 @@ namespace paddle { namespace operators { -namespace jit { - -template <> -size_t GetKey(int d) { - return d; -} - -// template <> -// std::shared_ptr CreateJitCode(int attr) -// { -// if (UseJitCode(attr)) { -// return std::make_shared>(attr, -// CodeSize(attr))); -// } -// return nullptr; -// } - -} // namespace jit +namespace jit {} // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h index 52b8da9a82a..caa3ef9dda7 100644 --- a/paddle/fluid/operators/jit/gen/jitcode.h +++ b/paddle/fluid/operators/jit/gen/jitcode.h @@ -70,9 +70,10 @@ typedef enum { class JitCode : public GenBase, public Xbyak::CodeGenerator { public: explicit JitCode(size_t code_size, void* code_ptr = nullptr) - : Xbyak::CodeGenerator(code_size, code_ptr) { - this->genCode(); - } + : Xbyak::CodeGenerator(code_size, code_ptr) {} + + virtual const char* name() const = 0; + virtual void genCode() = 0; size_t getSize() const override { return CodeGenerator::getSize(); } const unsigned char* getCodeInternal() override { @@ -80,9 +81,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator { return code; } - virtual const char* name() const = 0; - virtual void genCode() = 0; - protected: Xbyak::Reg64 param1{abi_param1}; const int EVEX_max_8b_offt = 0x200; diff --git a/paddle/fluid/operators/jit/gen_base.cc b/paddle/fluid/operators/jit/gen_base.cc index 310da0c76f1..a8bf9029637 100644 --- a/paddle/fluid/operators/jit/gen_base.cc +++ b/paddle/fluid/operators/jit/gen_base.cc @@ -23,6 +23,11 @@ namespace paddle { namespace operators { namespace jit { +template <> +size_t JitCodeKey(int d) { + return d; +} + // refer do not need useme, it would be the last one. void GenBase::dumpCode(const unsigned char* code) const { if (code) { diff --git a/paddle/fluid/operators/jit/gen_base.h b/paddle/fluid/operators/jit/gen_base.h index 4a136534dca..3b874cf2b01 100644 --- a/paddle/fluid/operators/jit/gen_base.h +++ b/paddle/fluid/operators/jit/gen_base.h @@ -15,9 +15,8 @@ #pragma once #include -#include // for shared_ptr +#include // for unique_ptr #include "paddle/fluid/operators/jit/kernel_base.h" -#include "paddle/fluid/platform/macros.h" DECLARE_bool(dump_jitcode); @@ -25,29 +24,12 @@ namespace paddle { namespace operators { namespace jit { -// TODO(TJ): make these functions as virtual of a class - -// Every JitCode should estimate the code size itself -template -size_t CodeSize(Attr attr) { - return 4096; -} - -// Every JitCode should have a condition when to use this JitCode -template -bool UseJitCode(Attr attr) { - return false; -} - -// Every JitCode should have a method to get the key from attribution -template -size_t GetKey(Attr attr); - class GenBase : public Kernel { public: + virtual ~GenBase() = default; virtual const char* name() const = 0; - virtual const unsigned char* getCodeInternal() = 0; virtual size_t getSize() const = 0; + virtual const unsigned char* getCodeInternal() = 0; template const FUNC getCode() { const unsigned char* code = this->getCodeInternal(); @@ -61,8 +43,31 @@ class GenBase : public Kernel { void dumpCode(const unsigned char* code) const; }; -template -std::unique_ptr CreateJitCode(Attr attr); +// Every JitCode should have a method to get the key from attribution +template +size_t JitCodeKey(Attr attr); + +// Creator is used to creat the jitcode and save in pool. +// Every JitCode should have one creator. +class GenCreator { + public: + virtual ~GenCreator() = default; +}; + +template +class JitCodeCreator : public GenCreator { + public: + virtual ~JitCodeCreator() = default; + + // condition when this jit code can be used. + virtual bool UseMe(const Attr& attr) const = 0; + + // estimate this code size + virtual size_t CodeSize(const Attr& attr) const = 0; + + // create this code + virtual std::unique_ptr CreateJitCode(const Attr& attr) const = 0; +}; } // namespace jit } // namespace operators diff --git a/paddle/fluid/operators/jit/kernel_pool.cc b/paddle/fluid/operators/jit/kernel_pool.cc index f300d28a6f0..bc98c644fbe 100644 --- a/paddle/fluid/operators/jit/kernel_pool.cc +++ b/paddle/fluid/operators/jit/kernel_pool.cc @@ -21,6 +21,11 @@ namespace paddle { namespace operators { namespace jit { +JitCodeCreatorPool& JitCodeCreatorPool::Instance() { + static JitCodeCreatorPool g_creator_pool; + return g_creator_pool; +} + KernelPool& KernelPool::Instance() { static KernelPool g_kernel_pool; return g_kernel_pool; diff --git a/paddle/fluid/operators/jit/kernel_pool.h b/paddle/fluid/operators/jit/kernel_pool.h index 737b7f60e3c..c9e7fc84e51 100644 --- a/paddle/fluid/operators/jit/kernel_pool.h +++ b/paddle/fluid/operators/jit/kernel_pool.h @@ -14,7 +14,7 @@ #pragma once -#include // for shared_ptr +#include // for unique_ptr #include #include #include @@ -52,6 +52,28 @@ class JitCodePool { DISABLE_COPY_AND_ASSIGN(JitCodePool); }; +class JitCodeCreatorPool { + typedef std::unique_ptr GenCreatorPtr; + typedef std::unordered_map, + KernelKey::Hash> + GenCreatorPtrMap; + + public: + JitCodeCreatorPool() = default; + static JitCodeCreatorPool& Instance(); + GenCreatorPtrMap& AllCreators() { return creators_; } + void Insert(const KernelKey& key, GenCreatorPtr value) { + if (creators_.find(key) == creators_.end()) { + creators_.emplace(key, std::vector()); + } + creators_.at(key).emplace_back(std::move(value)); + } + + private: + GenCreatorPtrMap creators_; + DISABLE_COPY_AND_ASSIGN(JitCodeCreatorPool); +}; + typedef std::unique_ptr KernelPtr; typedef std::unordered_map, KernelKey::Hash> KernelMap; @@ -113,24 +135,33 @@ inline Func GetRefer() { template const Func Get(Attr attr) { - size_t key = GetKey(attr); + size_t key = JitCodeKey(attr); auto& codes = JitCodePool().Instance(); if (codes.Has(key)) { return codes.AllKernels().at(key)->template getCode(); } + KernelKey kkey(KT, PlaceType()); if (std::is_same::value) { - auto p = CreateJitCode(attr); - if (p) { - auto f = p->template getCode(); - codes.Insert(key, std::move(p)); - return f; + // pool: (KernelKey(type, place), vector) + auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); + auto iter = creator_map.find(kkey); + auto& creators = iter->second; + for (auto& cur : creators) { + auto i = dynamic_cast*>(cur.get()); + if (i && i->UseMe(attr)) { + auto p = i->CreateJitCode(attr); + if (p) { + auto f = p->template getCode(); + codes.Insert(key, std::move(p)); + return f; + } + } } } - // pool: (KernelKey(type, place), vector) + // pool: (KernelKey(type, place), vector) auto& pool = KernelPool().Instance().AllKernels(); - KernelKey kkey(KT, PlaceType()); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; diff --git a/paddle/fluid/operators/jit/registry.h b/paddle/fluid/operators/jit/registry.h index c1f02d9cd57..cb32c487208 100644 --- a/paddle/fluid/operators/jit/registry.h +++ b/paddle/fluid/operators/jit/registry.h @@ -116,7 +116,30 @@ class JitKernelRegistrar { #define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \ REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__) -// REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode); +#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_jitkernel_gen_##kernel_type##_CPUPlace_, \ + "REGISTER_JITKERNEL_GEN must be called in global namespace"); \ + extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \ + TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static ::paddle::operators::jit::JitKernelRegistrar< \ + ::paddle::operators::jit::JitCodeCreatorPool, \ + ::paddle::platform::CPUPlace, __VA_ARGS__> \ + __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \ + ::paddle::operators::jit::KernelType::kernel_type); \ + int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \ + __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \ + return 0; \ + } + +#define USE_JITKERNEL_GEN(kernel_type) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_jitkernel_gen_##kernel_type##_CPUPlace_, \ + "USE_JITKERNEL_GEN must be called in global namespace"); \ + extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \ + static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \ + TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() #define USE_JITKERNEL_REFER(kernel_type) \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 836b6eee800..5af9ed697d6 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -61,6 +61,7 @@ void ExpectEQ(const T* target, const T* refer, int n) { // TODO(TJ): remove me USE_JITKERNEL_MORE(vmul, mkl); USE_JITKERNEL_REFER(vmul); +USE_JITKERNEL_GEN(vmul); TEST(JitKernel, vmul) { using T = float; -- GitLab