提交 45bfa70c 编写于 作者: T tensor-tang

complete vmul jit kernel

上级 77236e33
# set(use_jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
# file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n")
# file(APPEND ${pass_file} "\#pragma once\n")
# file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
cc_library(jit_kernel_base SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS}) cc_library(jit_kernel_base SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
add_subdirectory(more)
add_subdirectory(refer) add_subdirectory(refer)
add_subdirectory(more)
if(WITH_XBYAK) if(WITH_XBYAK)
add_subdirectory(jitcode) add_subdirectory(jitcode)
endif() endif()
# Debug
message(STATUS "--------${JIT_KERNEL_DEPS}")
cc_library(jit_kernel SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS}) cc_library(jit_kernel SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel)
TBD TBD
# Use me
Add USE_JIT_KERNEL(yourname) to CMakefile.
...@@ -13,3 +13,26 @@ ...@@ -13,3 +13,26 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h" #include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
namespace paddle {
namespace operators {
namespace jitkernels {
// Key computation for plain integer attributes: the attribute value itself,
// converted to size_t, is the key.
template <>
size_t GetKey<int>(int d) {
  const size_t key = static_cast<size_t>(d);
  return key;
}
// template <>
// std::shared_ptr<const JitBase> CreateJitCode<KernelType::vmul, int>(int attr)
// {
// if (UseJitCode<KernelType::vmul, int>(attr)) {
// return std::make_shared<jitcode::VMulJitCode<int>>(attr,
// CodeSize<KernelType::vmul, int>(attr)));
// }
// return nullptr;
// }
} // namespace jitkernels
} // namespace operators
} // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
#include "paddle/fluid/operators/jitkernels/kernels.h" #include "paddle/fluid/operators/jitkernels/kernels.h"
#define XBYAK_USE_MMAP_ALLOCATOR #define XBYAK_USE_MMAP_ALLOCATOR
...@@ -31,10 +32,10 @@ constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI), ...@@ -31,10 +32,10 @@ constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX), abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX); abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);
template <KernelType KT, typename Attr> template <typename Attr>
class JitCode : public JitBase, public Xbyak::CodeGenerator { class VMulJitCode : public JitBase, public Xbyak::CodeGenerator {
public: public:
JitCode(Attr attr, size_t code_size, void* code_ptr = nullptr) VMulJitCode(Attr attr, size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator(code_size, code_ptr) { : Xbyak::CodeGenerator(code_size, code_ptr) {
this->genCode(); this->genCode();
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <memory> // for shared_ptr
#include "paddle/fluid/operators/jitkernels/kernel_base.h" #include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -42,11 +43,6 @@ bool UseJitCode(Attr attr) { ...@@ -42,11 +43,6 @@ bool UseJitCode(Attr attr) {
template <typename Attr> template <typename Attr>
size_t GetKey(Attr attr); size_t GetKey(Attr attr);
template <>
size_t GetKey<int>(int d) {
return d;
}
class JitBase { class JitBase {
public: public:
JitBase() = default; JitBase() = default;
...@@ -68,6 +64,9 @@ class JitBase { ...@@ -68,6 +64,9 @@ class JitBase {
void dumpCode(const unsigned char* code); void dumpCode(const unsigned char* code);
}; };
template <KernelType KT, typename Attr>
std::shared_ptr<const JitBase> CreateJitCode(Attr attr);
} // namespace jitkernels } // namespace jitkernels
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -25,6 +25,7 @@ typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType; ...@@ -25,6 +25,7 @@ typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;
class Kernel { class Kernel {
public: public:
Kernel() = default; Kernel() = default;
virtual ~Kernel() = default;
DISABLE_COPY_AND_ASSIGN(Kernel); DISABLE_COPY_AND_ASSIGN(Kernel);
}; };
...@@ -32,16 +33,20 @@ template <typename T, typename Func, typename Attr> // TODO(TJ): use tuple ...@@ -32,16 +33,20 @@ template <typename T, typename Func, typename Attr> // TODO(TJ): use tuple
class KernelImpl : public Kernel { class KernelImpl : public Kernel {
public: public:
using ELEMENT_TYPE = T; // TODO(TJ): remove me? using ELEMENT_TYPE = T; // TODO(TJ): remove me?
KernelImpl() = default; virtual Func GetFunc() const { return func; }
virtual ~KernelImpl() = default;
virtual Func GetFunc() { return func; }
virtual bool UseMe(Attr attr) const = 0; virtual bool UseMe(Attr attr) const = 0;
protected: protected:
Func func{nullptr}; Func func{nullptr};
}; };
// Common base for reference kernels. UseMe() reports true for every
// attribute value, so a reference kernel is never rejected by that check.
template <typename T, typename Func, typename Attr>  // TODO(TJ): use tuple
class ReferKernel : public KernelImpl<T, Func, Attr> {
 public:
  // The attribute is intentionally ignored.
  bool UseMe(Attr /*attr*/) const override {
    return true;
  }
};
} // namespace jitkernels } // namespace jitkernels
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -21,13 +21,16 @@ namespace paddle { ...@@ -21,13 +21,16 @@ namespace paddle {
namespace operators { namespace operators {
namespace jitkernels { namespace jitkernels {
// refer do not need useme, it would be the last one.
KernelPool& KernelPool::Instance() { KernelPool& KernelPool::Instance() {
static KernelPool g_kernel_pool; static KernelPool g_kernel_pool;
return g_kernel_pool; return g_kernel_pool;
} }
// Accessor for the shared ReferKernelPool. The pool is a function-local
// static, so it is constructed the first time this function runs.
ReferKernelPool& ReferKernelPool::Instance() {
  static ReferKernelPool refer_pool_singleton;
  return refer_pool_singleton;
}
} // namespace jitkernels } // namespace jitkernels
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -18,22 +18,21 @@ ...@@ -18,22 +18,21 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jitkernels/jitcode_base.h" #include "paddle/fluid/operators/jitkernels/jitcode_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_base.h" #include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_key.h" #include "paddle/fluid/operators/jitkernels/kernel_key.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace jitkernels { namespace jitkernels {
// TODO(TJ): rename file to kernel_pool
template <KernelType KT> template <KernelType KT>
class JitCodePool { class JitCodePool {
public: public:
JitCodePool() = default;
static JitCodePool& Instance() { static JitCodePool& Instance() {
static thread_local JitCodePool<KT> g_jit_codes; static thread_local JitCodePool<KT> g_jit_codes;
return g_jit_codes; return g_jit_codes;
...@@ -51,13 +50,11 @@ class JitCodePool { ...@@ -51,13 +50,11 @@ class JitCodePool {
} }
private: private:
JitCodePool() = default;
std::unordered_map<size_t, std::shared_ptr<const JitBase>> codes_; std::unordered_map<size_t, std::shared_ptr<const JitBase>> codes_;
DISABLE_COPY_AND_ASSIGN(JitCodePool); DISABLE_COPY_AND_ASSIGN(JitCodePool);
}; };
// std::tuple<T, Func, Attr> // TODO(TJ): std::tuple<T, Func, Attr>
template <typename T, typename Func, typename Attr> template <typename T, typename Func, typename Attr>
struct KernelAttr { struct KernelAttr {
typedef T data_type; typedef T data_type;
...@@ -65,76 +62,99 @@ struct KernelAttr { ...@@ -65,76 +62,99 @@ struct KernelAttr {
typedef Attr attr_type; typedef Attr attr_type;
}; };
typedef std::unique_ptr<const Kernel> KernelPtr;
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
KernelMap;
class KernelPool { class KernelPool {
public: public:
static KernelPool& Instance(); static KernelPool& Instance();
KernelPool() = default;
typedef std::unique_ptr<const Kernel> KernelPtr;
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
KernelMap;
KernelMap& AllKernels() { return pool_; } KernelMap& AllKernels() { return pool_; }
void Insert(const KernelKey& key, KernelPtr value) { void Insert(const KernelKey& key, KernelPtr value) {
if (pool_.find(key) == pool_.end()) { if (pool_.find(key) == pool_.end()) {
pool_.emplace(key, std::vector<KernelPtr>()); pool_.emplace(key, std::vector<KernelPtr>());
} }
pool_.at(key).emplace_back(std::move(value)); pool_.at(key).emplace_back(std::move(value));
} }
KernelPool() = default;
private: private:
KernelMap pool_; KernelMap pool_;
DISABLE_COPY_AND_ASSIGN(KernelPool); DISABLE_COPY_AND_ASSIGN(KernelPool);
}; };
// TODO(TJ): create_jitcode; // Every kernel should have refer code and it should be used in unit tests,
// so refer kernels should have their own independent kernel pool
// Registry that stores reference kernels, keyed by KernelKey. Each key maps
// to a list of kernels owned by the pool.
class ReferKernelPool {
 public:
  // Returns the shared pool instance.
  static ReferKernelPool& Instance();
  ReferKernelPool() = default;

  // Gives mutable access to the underlying key -> kernel-list map.
  KernelMap& AllKernels() { return pool_; }

  // Appends `value` to the list stored under `key`, creating the list when
  // the key is seen for the first time. Ownership of the kernel moves into
  // the pool.
  void Insert(const KernelKey& key, KernelPtr value) {
    // operator[] default-constructs an empty vector for a new key, so one
    // hash lookup is enough for both the "new key" and "existing key" cases.
    pool_[key].emplace_back(std::move(value));
  }

 private:
  KernelMap pool_;
  DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
};
// Refer code do not related with attr, and always on CPUPlace
template <KernelType KT, typename T, typename Func, typename Attr>
inline Func GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
// TODO(TJ): make tuple? named KernelAttr // TODO(TJ): make tuple? named KernelAttr
template <KernelType KT, typename T, typename Func, typename Attr, template <KernelType KT, typename T, typename Func, typename Attr,
typename PlaceType = platform::CPUPlace> typename PlaceType = platform::CPUPlace>
Func Get(Attr attr) { Func Get(Attr attr) {
size_t key = GetKey<Attr>(attr); // size_t key = GetKey<Attr>(attr);
auto jitcode = JitCodePool<KT>().Instance().Get(key); // auto jitcode = JitCodePool<KT>().Instance().Get(key);
if (jitcode) { // if (jitcode) {
return jitcode->template getCode<Func>(); // return jitcode->template getCode<Func>();
// }
if (std::is_same<PlaceType, platform::CPUPlace>::value &&
std::is_same<T, float>::value) { // TODO(TJ): float move to create
// auto p = CreateJitCode<KT, Attr>(attr);
// if (p) {
// JitCodePool<KT>().Instance().Insert(key, p);
// return p->template getCode<Func>();
// }
} }
#ifdef PADDLE_WITH_XBYAK // pool: (KernelKey(type, place), vector<Kernel>)
// // jitcode::JitCode is under protection of PADDLE_WITH_XBYAK
// if (std::is_same<PlaceType, platform::CPUPlace>::value) {
// if (UseJitCode<KT, T, Attr>(attr)) {
// std::shared_ptr<JitBase> p(std::make_shared<jitcode::JitCode<KT, Attr>>(
// attr, CodeSize<KT, Attr>(attr)));
// JitCodePool<KT>().Instance().Insert(key, p);
// return p->getCode<Func>();
// }
// }
#endif
// (KernelKey(type, place), vector<Kernel>)
auto& pool = KernelPool().Instance().AllKernels(); auto& pool = KernelPool().Instance().AllKernels();
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KT, PlaceType());
auto iter = pool.find(kkey); auto iter = pool.find(kkey);
if (iter != pool.end()) { if (iter != pool.end()) {
auto impls = iter->second; auto& impls = iter->second;
for (auto impl : impls) { for (auto& impl : impls) {
auto i = std::dynamic_pointer_cast<KernelImpl<T, Func, Attr>>(impl.get()); auto i = dynamic_cast<const KernelImpl<T, Func, Attr>*>(impl.get());
if (i && i->UseMe(attr)) { if (i && i->UseMe(attr)) {
return i->GetFunc(); return i->GetFunc();
} }
} }
} }
// The last implementation should be reference function on CPU // The last implementation should be reference function on CPUPlace.
// Every kernel should have refer code. return GetRefer<KT, T, Func, Attr>();
// because of test refer should have it's own pool
// PADDLE_ENFORCE_GT(list.size(), 1) << "Should have refer implemtation";
// const auto& refer = KernelRefer<KT, T>().AllKernels();
// return refer.Get<Func>();
return nullptr;
} }
} // namespace jitkernels } // namespace jitkernels
......
...@@ -17,4 +17,5 @@ ...@@ -17,4 +17,5 @@
namespace refer = paddle::operators::jitkernels::refer; namespace refer = paddle::operators::jitkernels::refer;
// REGISTER_JITKERNEL_REFER(vmul, refer::VMul<float>, refer::VMul<double>); REGISTER_JITKERNEL_REFER(vmul, refer::VMulKernel<float>,
refer::VMulKernel<double>);
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -27,6 +28,13 @@ void VMul(const T* x, const T* y, T* z, int n) { ...@@ -27,6 +28,13 @@ void VMul(const T* x, const T* y, T* z, int n) {
} }
} }
// Reference kernel for vmul: a ReferKernel whose stored function pointer is
// set to VMul<T> at construction.
template <typename T>
class VMulKernel
    : public ReferKernel<T, void (*)(const T*, const T*, T*, int), int> {
 public:
  VMulKernel() {
    // `this->` is required because func lives in a dependent base class.
    this->func = &VMul<T>;
  }
};
} // namespace refer } // namespace refer
} // namespace jitkernels } // namespace jitkernels
} // namespace operators } // namespace operators
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/operators/jitkernels/kernel_base.h" #include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernels.h" #include "paddle/fluid/operators/jitkernels/kernels.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" // for UNUSED
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -32,37 +33,40 @@ inline std::unique_ptr<T> make_unique(Args&&... args) { ...@@ -32,37 +33,40 @@ inline std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
} }
template <typename PlaceType, bool IsEnd, size_t I, typename... KernelImpls> template <typename Pool, typename PlaceType, bool IsEnd, size_t I,
typename... KernelImpls>
struct JitKernelRegistrarFunctor; struct JitKernelRegistrarFunctor;
template <typename PlaceType, size_t I, typename... KernelImpls> template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<PlaceType, true, I, KernelImpls...> { struct JitKernelRegistrarFunctor<Pool, PlaceType, true, I, KernelImpls...> {
void operator()(KernelType kt) const {} void operator()(KernelType kt) const {}
}; };
template <typename PlaceType, size_t I, typename... KernelImpls> template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<PlaceType, false, I, KernelImpls...> { struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
using KERNEL_IMPL_TYPE = using KERNEL_IMPL_TYPE =
typename std::tuple_element<I, std::tuple<KernelImpls...>>::type; typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;
void operator()(KernelType kt) const { void operator()(KernelType kt) const {
KernelKey kkey(kt, PlaceType()); KernelKey kkey(kt, PlaceType());
KernelPool().Instance().Insert( Pool().Instance().Insert(kkey,
kkey, std::move(make_unique<const KERNEL_IMPL_TYPE>())); std::move(make_unique<const KERNEL_IMPL_TYPE>()));
constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value; constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
JitKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelImpls...> JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
KernelImpls...>
func; func;
func(kt); func(kt);
} }
}; };
template <typename PlaceType, typename... KernelImpls> template <typename Pool, typename PlaceType, typename... KernelImpls>
class JitKernelRegistrar { class JitKernelRegistrar {
public: public:
explicit JitKernelRegistrar(KernelType kt) { explicit JitKernelRegistrar(KernelType kt) {
JitKernelRegistrarFunctor<PlaceType, false, 0, KernelImpls...> func; JitKernelRegistrarFunctor<Pool, PlaceType, false, 0, KernelImpls...> func;
func(kt); func(kt);
} }
void Touch() {}
}; };
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \ #define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
...@@ -71,17 +75,40 @@ class JitKernelRegistrar { ...@@ -71,17 +75,40 @@ class JitKernelRegistrar {
__test_global_namespace_##uniq_name##__>::value, \ __test_global_namespace_##uniq_name##__>::value, \
msg) msg)
// Registers the listed reference kernel implementations for `kernel_type` in
// ReferKernelPool. Refer kernels are always on CPUPlace.
//
// Must be invoked in the global namespace (checked by a static_assert).
// Besides the static registrar object, the macro defines the function
// TouchJitKernelReg_<kernel_type>_refer_CPUPlace_(), which calls Touch() on
// the registrar and returns 0.
#define REGISTER_JITKERNEL_REFER(kernel_type, ...)                     \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                            \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace,                  \
      "REGISTER_JITKERNEL_REFER must be called in global namespace");  \
  static ::paddle::operators::jitkernels::JitKernelRegistrar<          \
      ::paddle::operators::jitkernels::ReferKernelPool,                \
      ::paddle::platform::CPUPlace, __VA_ARGS__>                       \
      __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_(          \
          ::paddle::operators::jitkernels::KernelType::kernel_type);   \
  int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() {            \
    __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch();    \
    return 0;                                                          \
  }
// kernel_type: should be in paddle::operators::jitkernels::KernelType // kernel_type: should be in paddle::operators::jitkernels::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform // place_type: should be one of CPUPlace and GPUPlace in paddle::platform
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \ #define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \ __reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
"REGISTER_KERNEL_MORE must be called in global namespace"); \ "REGISTER_KERNEL_MORE must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jitkernels::JitKernelRegistrar< \ static ::paddle::operators::jitkernels::JitKernelRegistrar< \
::paddle::operators::jitkernels::KernelPool, \
::paddle::platform::place_type, __VA_ARGS__> \ ::paddle::platform::place_type, __VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##__( \ __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::operators::jitkernels::KernelType::kernel_type) ::paddle::operators::jitkernels::KernelType::kernel_type); \
// TODO(TJ): Add Touch and use me int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \
return 0; \
}
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \ #define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__) REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
...@@ -89,45 +116,28 @@ class JitKernelRegistrar { ...@@ -89,45 +116,28 @@ class JitKernelRegistrar {
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \ #define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__) REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
/* // REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);
REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);
// refer must be only one and at least one
REGISTER_JITKERNEL_REFER(vmul, VMul); // Refer need support dtype
// you can register more implementations and the condition when use it #define USE_JITKERNEL_REFER(kernel_type) \
REGISTER_JITKERNEL_MORE(vmul, mkl::VMUL<float>, UseMe<float>, mkl::VMUL<double>, STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
UseMe<double>) __reg_jitkernel_##kernel_type##_refer_CPUPlace_, \
"USE_JITKERNEL_REFER must be called in global namespace"); \
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
struct __test_global_namespace_##uniq_name##__ {}; \ static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR. #define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \
#define REGISTER_PASS(pass_type, pass_class) \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \
__reg_pass__##pass_type, \ "USE_JITKERNEL_MORE must be called in global namespace"); \
"REGISTER_PASS must be called in global namespace"); \ extern int \
static ::paddle::framework::ir::PassRegistrar<pass_class> \ TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
__pass_registrar_##pass_type##__(#pass_type); \ static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
int TouchPassRegistrar_##pass_type() { \ UNUSED = \
__pass_registrar_##pass_type##__.Touch(); \ TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
return 0; \
} \ #define USE_JITKERNEL_MORE(kernel_type, impl_type) \
static ::paddle::framework::ir::PassRegistrar<pass_class>& \ USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
__pass_tmp_registrar_##pass_type##__ UNUSED = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ UNUSED = \
TouchPassRegistrar_##pass_type()
*/
} // namespace jitkernels } // namespace jitkernels
} // namespace operators } // namespace operators
......
...@@ -19,8 +19,11 @@ ...@@ -19,8 +19,11 @@
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/jitkernels/kernels.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h" // TODO(TJ): remove me
#include "paddle/fluid/operators/jitkernels/registry.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
constexpr int repeat = 20000; constexpr int repeat = 20000;
...@@ -31,6 +34,75 @@ inline double GetCurrentUS() { ...@@ -31,6 +34,75 @@ inline double GetCurrentUS() {
return 1e+6 * time.tv_sec + time.tv_usec; return 1e+6 * time.tv_sec + time.tv_usec;
} }
TEST(JitKernel, vmul) {} template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f)) {
// Fills a[0..n) with pseudo-random values: each element is a draw from
// uniform_dist scaled by (upper - lower) and shifted by lower.
// The seed is a function-local static that starts at 100 and is incremented
// on every call, so successive calls use different seeds.
static unsigned int seed = 100;
std::mt19937 rng(seed++);
// Draws are made in double precision and cast to T only when stored.
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
// Compares the first n elements of `target` against `refer`.
// Floating-point element types are compared with an absolute tolerance of
// 1e-3; every other element type must match exactly.
template <typename T>
void ExpectEQ(const T* target, const T* refer, int n) {
  const bool use_tolerance = std::is_floating_point<T>::value;
  for (int idx = 0; idx < n; ++idx) {
    if (use_tolerance) {
      EXPECT_NEAR(target[idx], refer[idx], 1e-3);
    } else {
      EXPECT_EQ(target[idx], refer[idx]);
    }
  }
}
// TODO(TJ): remove me
USE_JITKERNEL_MORE(vmul, mkl);
USE_JITKERNEL_REFER(vmul);
// Checks that the vmul kernel returned by jit::Get on CPUPlace produces the
// same results as the reference kernel for every vector length in [1, 30).
// Each length is exercised three ways: separate output buffer, output
// aliased with x, and output aliased with y.
TEST(JitKernel, vmul) {
  using T = float;
  using PlaceType = paddle::platform::CPUPlace;
  namespace jit = paddle::operators::jitkernels;
  using VMulFunc = void (*)(const T*, const T*, T*, int);

  // TODO(TJ): test more vector size
  for (int d = 1; d < 30; ++d) {
    auto ref = jit::GetRefer<jit::vmul, T, VMulFunc, int>();
    auto tgt = jit::Get<jit::vmul, T, VMulFunc, int, PlaceType>(d);
    // Both pointers are called below, so a null one must stop the test here
    // rather than be recorded as a failure and then dereferenced.
    ASSERT_TRUE(ref != nullptr);
    ASSERT_TRUE(tgt != nullptr);

    std::vector<T> x(d), y(d);
    std::vector<T> zref(d), ztgt(d);
    RandomVec<T>(d, x.data());
    RandomVec<T>(d, y.data());

    // Use the T alias so the buffers follow the element type under test.
    const T* x_data = x.data();
    const T* y_data = y.data();
    T* ztgt_data = ztgt.data();
    T* zref_data = zref.data();

    tgt(x_data, y_data, ztgt_data, d);
    ref(x_data, y_data, zref_data, d);
    ExpectEQ<T>(ztgt_data, zref_data, d);

    // test inplace x
    std::copy(x.begin(), x.end(), zref.begin());
    std::copy(x.begin(), x.end(), ztgt.begin());
    tgt(ztgt_data, y_data, ztgt_data, d);
    ref(zref_data, y_data, zref_data, d);
    ExpectEQ<T>(ztgt_data, zref_data, d);

    // test inplace y
    std::copy(y.begin(), y.end(), zref.begin());
    std::copy(y.begin(), y.end(), ztgt.begin());
    tgt(x_data, ztgt_data, ztgt_data, d);
    ref(x_data, zref_data, zref_data, d);
    ExpectEQ<T>(ztgt_data, zref_data, d);
  }
}
TEST(JitKernel, pool) {} TEST(JitKernel, pool) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册