提交 a3c67b2f 编写于 作者: 石晓伟 提交者: GitHub

Jit macro definition ambiguity fix, test=develop (#2713)

上级 ff13e3af
......@@ -89,7 +89,7 @@ All kernels are inlcuded in `lite/backends/x86/jit/kernels.h`, which is automati
3. Add reference function of `your_key`.
Note:
- This should run on CPU and must not depend on any third-party library.
- Add `USE_JITKERNEL_REFER(your_key)` in `refer/CmakeLists.txt` to make sure this code can be used.
- Add `USE_JITKERNEL_REFER_LITE(your_key)` in `refer/CmakeLists.txt` to make sure this code can be used.
4. Add unit tests in `test.cc`, and verify at least `float` and `double`.
Test more data types for some special functions if necessary, for example `int8`.
5. Add functions in `benchmark.cc` to test all functions of the same `KernelType`. Make sure `GetDefaultBestFunc` always gets the best one.
......
......@@ -79,7 +79,7 @@ PaddlePaddle/Paddle/paddle/fluid/
# 如何添加新的算子
1. 在`KernelType` 中添加 `your_key`
2. 实现Reference 的逻辑,这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在`refer/CmakeLists.txt`中添加`USE_JITKERNEL_REFER(your_key)`来使用该kernel。
2. 实现Reference 的逻辑,这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在`refer/CmakeLists.txt`中添加`USE_JITKERNEL_REFER_LITE(your_key)`来使用该kernel。
3. (optional) 实现更多的算法在`more`目录下,可以依赖mkl,intrinsic或者mkldnn等第三方库。
4. (optional) 实现基于Xbyak的生成code,在`gen`目录下。 jitcode需要实现自己的`JitCodeCreator`,并注册在与refer相同的`KernelType`上。
5. 添加新的`KernelTuple`,需要与`KernelType`一一对应,是所有类型的一个打包,包括数据类型,属性的类型,以及返回的函数类型。可以参考`SeqPoolTuple`,新加的Attr类型需要特例化`JitCodeKey`方法。
......
......@@ -4,33 +4,33 @@ file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
function(USE_JITKERNEL_GEN TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n")
# Appends one `USE_JITKERNEL_GEN_LITE(<TARGET>);` line to the file named by
# ${jit_file}, where <TARGET> is the kernel key passed by the caller
# (e.g. kMatMul).
# NOTE(review): ${jit_file} is set outside this snippet; presumably it is the
# generated kernels header -- confirm against the parent CMakeLists.
function(USE_JITKERNEL_GEN_LITE TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_GEN_LITE(${TARGET});\n")
endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(kMatMul)
USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd)
USE_JITKERNEL_GEN(kVSub)
USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVSquare)
USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid)
USE_JITKERNEL_GEN(kVTanh)
USE_JITKERNEL_GEN(kLSTMCtHt)
USE_JITKERNEL_GEN(kLSTMC1H1)
USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN(kGRUHtPart2)
USE_JITKERNEL_GEN(kNCHW16CMulNC)
USE_JITKERNEL_GEN(kSeqPool)
USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool)
USE_JITKERNEL_GEN(kSgd)
USE_JITKERNEL_GEN(kVBroadcast)
USE_JITKERNEL_GEN_LITE(kMatMul)
USE_JITKERNEL_GEN_LITE(kVMul)
USE_JITKERNEL_GEN_LITE(kVAdd)
USE_JITKERNEL_GEN_LITE(kVSub)
USE_JITKERNEL_GEN_LITE(kVAddRelu)
USE_JITKERNEL_GEN_LITE(kVScal)
USE_JITKERNEL_GEN_LITE(kVAddBias)
USE_JITKERNEL_GEN_LITE(kVRelu)
USE_JITKERNEL_GEN_LITE(kVSquare)
USE_JITKERNEL_GEN_LITE(kVIdentity)
USE_JITKERNEL_GEN_LITE(kVExp)
USE_JITKERNEL_GEN_LITE(kVSigmoid)
USE_JITKERNEL_GEN_LITE(kVTanh)
USE_JITKERNEL_GEN_LITE(kLSTMCtHt)
USE_JITKERNEL_GEN_LITE(kLSTMC1H1)
USE_JITKERNEL_GEN_LITE(kGRUH1)
USE_JITKERNEL_GEN_LITE(kGRUHtPart1)
USE_JITKERNEL_GEN_LITE(kGRUHtPart2)
USE_JITKERNEL_GEN_LITE(kNCHW16CMulNC)
USE_JITKERNEL_GEN_LITE(kSeqPool)
USE_JITKERNEL_GEN_LITE(kHMax)
USE_JITKERNEL_GEN_LITE(kHSum)
USE_JITKERNEL_GEN_LITE(kEmbSeqPool)
USE_JITKERNEL_GEN_LITE(kSgd)
USE_JITKERNEL_GEN_LITE(kVBroadcast)
......@@ -156,9 +156,9 @@ size_t VTanhCreator::CodeSize(const int& d) const {
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator);
REGISTER_JITKERNEL_GEN_LITE(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN_LITE(kVSquare, gen::VSquareCreator);
REGISTER_JITKERNEL_GEN_LITE(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN_LITE(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN_LITE(kVSigmoid, gen::VSigmoidCreator);
REGISTER_JITKERNEL_GEN_LITE(kVTanh, gen::VTanhCreator);
......@@ -181,10 +181,10 @@ DECLARE_BLAS_CREATOR(VAddBias);
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator);
REGISTER_JITKERNEL_GEN_LITE(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN_LITE(kVAdd, gen::VAddCreator);
REGISTER_JITKERNEL_GEN_LITE(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN_LITE(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN_LITE(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN_LITE(kVAddBias, gen::VAddBiasCreator);
REGISTER_JITKERNEL_GEN_LITE(kNCHW16CMulNC, gen::NCHW16CMulNCCreator);
......@@ -145,4 +145,4 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kEmbSeqPool, gen::EmbSeqPoolCreator);
REGISTER_JITKERNEL_GEN_LITE(kEmbSeqPool, gen::EmbSeqPoolCreator);
......@@ -111,6 +111,6 @@ DECLARE_GRU_CREATOR(GRUHtPart2);
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator);
REGISTER_JITKERNEL_GEN_LITE(kGRUH1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN_LITE(kGRUHtPart1, gen::GRUHtPart1Creator);
REGISTER_JITKERNEL_GEN_LITE(kGRUHtPart2, gen::GRUHtPart2Creator);
......@@ -99,5 +99,5 @@ DECLARE_HOP_CREATOR(HSum);
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator);
REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator);
REGISTER_JITKERNEL_GEN_LITE(kHMax, gen::HMaxCreator);
REGISTER_JITKERNEL_GEN_LITE(kHSum, gen::HSumCreator);
......@@ -138,5 +138,5 @@ DECLARE_LSTM_CREATOR(LSTMC1H1);
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);
REGISTER_JITKERNEL_GEN_LITE(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN_LITE(kLSTMC1H1, gen::LSTMC1H1Creator);
......@@ -130,4 +130,4 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kMatMul, gen::MatMulCreator);
REGISTER_JITKERNEL_GEN_LITE(kMatMul, gen::MatMulCreator);
......@@ -82,4 +82,4 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator);
REGISTER_JITKERNEL_GEN_LITE(kSeqPool, gen::SeqPoolCreator);
......@@ -127,4 +127,4 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kSgd, gen::SgdCreator);
REGISTER_JITKERNEL_GEN_LITE(kSgd, gen::SgdCreator);
......@@ -88,4 +88,4 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
namespace gen = paddle::lite::jit::gen;
REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator);
REGISTER_JITKERNEL_GEN_LITE(kVBroadcast, gen::VBroadcastCreator);
function(USE_JITKERNEL_MORE TARGET TYPE)
file(APPEND ${jit_file} "USE_JITKERNEL_MORE(${TARGET} ${TYPE});\n")
# Appends one `USE_JITKERNEL_MORE_LITE(<TARGET> <TYPE>);` line to the file named
# by ${jit_file}.
# Call sites are written as `USE_JITKERNEL_MORE_LITE(kMatMul, mkl)`. CMake
# separates arguments by whitespace, not commas, so the comma stays attached to
# TARGET ("kMatMul,") and the emitted C++ line still reads
# `USE_JITKERNEL_MORE_LITE(kMatMul, mkl);`. Do not add a comma to the format
# string below.
# NOTE(review): ${jit_file} is set outside this snippet -- confirm its value.
function(USE_JITKERNEL_MORE_LITE TARGET TYPE)
file(APPEND ${jit_file} "USE_JITKERNEL_MORE_LITE(${TARGET} ${TYPE});\n")
endfunction()
# enable it later
......
......@@ -5,5 +5,5 @@ cc_library(jit_kernel_intrinsic SRCS ${jit_kernel_cc_intrinsic} DEPS jit_kernel_
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE)
# use intrinsic kernels by name and type
USE_JITKERNEL_MORE(kCRFDecoding, intrinsic)
USE_JITKERNEL_MORE(kLayerNorm, intrinsic)
USE_JITKERNEL_MORE_LITE(kCRFDecoding, intrinsic)
USE_JITKERNEL_MORE_LITE(kLayerNorm, intrinsic)
......@@ -5,11 +5,11 @@ cc_library(jit_kernel_mix SRCS ${jit_kernel_mix_cc} DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE)
USE_JITKERNEL_MORE(kVSigmoid, mix)
USE_JITKERNEL_MORE(kVTanh, mix)
USE_JITKERNEL_MORE(kLSTMCtHt, mix)
USE_JITKERNEL_MORE(kLSTMC1H1, mix)
USE_JITKERNEL_MORE(kGRUH1, mix)
USE_JITKERNEL_MORE(kGRUHtPart1, mix)
USE_JITKERNEL_MORE(kGRUHtPart2, mix)
USE_JITKERNEL_MORE(kSoftmax, mix)
USE_JITKERNEL_MORE_LITE(kVSigmoid, mix)
USE_JITKERNEL_MORE_LITE(kVTanh, mix)
USE_JITKERNEL_MORE_LITE(kLSTMCtHt, mix)
USE_JITKERNEL_MORE_LITE(kLSTMC1H1, mix)
USE_JITKERNEL_MORE_LITE(kGRUH1, mix)
USE_JITKERNEL_MORE_LITE(kGRUHtPart1, mix)
USE_JITKERNEL_MORE_LITE(kGRUHtPart2, mix)
USE_JITKERNEL_MORE_LITE(kSoftmax, mix)
......@@ -3,18 +3,18 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kStrideScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVCopy, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
USE_JITKERNEL_MORE(kSoftmax, mkl)
USE_JITKERNEL_MORE(kEmbSeqPool, mkl)
USE_JITKERNEL_MORE(kSgd, mkl)
USE_JITKERNEL_MORE(kVBroadcast, mkl)
USE_JITKERNEL_MORE_LITE(kMatMul, mkl)
USE_JITKERNEL_MORE_LITE(kVMul, mkl)
USE_JITKERNEL_MORE_LITE(kVAdd, mkl)
USE_JITKERNEL_MORE_LITE(kVScal, mkl)
USE_JITKERNEL_MORE_LITE(kStrideScal, mkl)
USE_JITKERNEL_MORE_LITE(kVExp, mkl)
USE_JITKERNEL_MORE_LITE(kVSquare, mkl)
USE_JITKERNEL_MORE_LITE(kVCopy, mkl)
USE_JITKERNEL_MORE_LITE(kVSigmoid, mkl)
USE_JITKERNEL_MORE_LITE(kVTanh, mkl)
USE_JITKERNEL_MORE_LITE(kSeqPool, mkl)
USE_JITKERNEL_MORE_LITE(kSoftmax, mkl)
USE_JITKERNEL_MORE_LITE(kEmbSeqPool, mkl)
USE_JITKERNEL_MORE_LITE(kSgd, mkl)
USE_JITKERNEL_MORE_LITE(kVBroadcast, mkl)
......@@ -2,39 +2,39 @@
cc_library(jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_refer PARENT_SCOPE)
function(USE_JITKERNEL_REFER TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_REFER(${TARGET});\n")
# Appends one `USE_JITKERNEL_REFER_LITE(<TARGET>);` line to the file named by
# ${jit_file}, where <TARGET> is the kernel key of a reference implementation
# (e.g. kVMul).
# NOTE(review): ${jit_file} is set outside this snippet -- confirm its value.
function(USE_JITKERNEL_REFER_LITE TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_REFER_LITE(${TARGET});\n")
endfunction()
# use refer kernel by name
USE_JITKERNEL_REFER(kVMul)
USE_JITKERNEL_REFER(kVAdd)
USE_JITKERNEL_REFER(kVAddRelu)
USE_JITKERNEL_REFER(kVSub)
USE_JITKERNEL_REFER(kVScal)
USE_JITKERNEL_REFER(kStrideScal)
USE_JITKERNEL_REFER(kVAddBias)
USE_JITKERNEL_REFER(kVCopy)
USE_JITKERNEL_REFER(kVRelu)
USE_JITKERNEL_REFER(kVIdentity)
USE_JITKERNEL_REFER(kVExp)
USE_JITKERNEL_REFER(kVSigmoid)
USE_JITKERNEL_REFER(kVTanh)
USE_JITKERNEL_REFER(kLSTMCtHt)
USE_JITKERNEL_REFER(kLSTMC1H1)
USE_JITKERNEL_REFER(kGRUH1)
USE_JITKERNEL_REFER(kGRUHtPart1)
USE_JITKERNEL_REFER(kGRUHtPart2)
USE_JITKERNEL_REFER(kCRFDecoding)
USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool)
USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kStrideASum)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kSgd)
USE_JITKERNEL_REFER(kVBroadcast)
USE_JITKERNEL_REFER_LITE(kVMul)
USE_JITKERNEL_REFER_LITE(kVAdd)
USE_JITKERNEL_REFER_LITE(kVAddRelu)
USE_JITKERNEL_REFER_LITE(kVSub)
USE_JITKERNEL_REFER_LITE(kVScal)
USE_JITKERNEL_REFER_LITE(kStrideScal)
USE_JITKERNEL_REFER_LITE(kVAddBias)
USE_JITKERNEL_REFER_LITE(kVCopy)
USE_JITKERNEL_REFER_LITE(kVRelu)
USE_JITKERNEL_REFER_LITE(kVIdentity)
USE_JITKERNEL_REFER_LITE(kVExp)
USE_JITKERNEL_REFER_LITE(kVSigmoid)
USE_JITKERNEL_REFER_LITE(kVTanh)
USE_JITKERNEL_REFER_LITE(kLSTMCtHt)
USE_JITKERNEL_REFER_LITE(kLSTMC1H1)
USE_JITKERNEL_REFER_LITE(kGRUH1)
USE_JITKERNEL_REFER_LITE(kGRUHtPart1)
USE_JITKERNEL_REFER_LITE(kGRUHtPart2)
USE_JITKERNEL_REFER_LITE(kCRFDecoding)
USE_JITKERNEL_REFER_LITE(kLayerNorm)
USE_JITKERNEL_REFER_LITE(kNCHW16CMulNC)
USE_JITKERNEL_REFER_LITE(kSeqPool)
USE_JITKERNEL_REFER_LITE(kMatMul)
USE_JITKERNEL_REFER_LITE(kVSquare)
USE_JITKERNEL_REFER_LITE(kHSum)
USE_JITKERNEL_REFER_LITE(kHMax)
USE_JITKERNEL_REFER_LITE(kStrideASum)
USE_JITKERNEL_REFER_LITE(kSoftmax)
USE_JITKERNEL_REFER_LITE(kEmbSeqPool)
USE_JITKERNEL_REFER_LITE(kSgd)
USE_JITKERNEL_REFER_LITE(kVBroadcast)
......@@ -18,7 +18,7 @@
namespace refer = paddle::lite::jit::refer;
#define REGISTER_REFER_KERNEL(func) \
REGISTER_JITKERNEL_REFER( \
REGISTER_JITKERNEL_REFER_LITE( \
k##func, refer::func##Kernel<float>, refer::func##Kernel<double>)
REGISTER_REFER_KERNEL(VMul);
......
......@@ -77,16 +77,16 @@ class JitKernelRegistrar {
void Touch() {}
};
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Refer always on CPUPlace
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
#define REGISTER_JITKERNEL_REFER_LITE(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \
__reg_litejitkernel_##kernel_type##_refer_CPUPlace, \
"REGISTER_KERNEL_REFER must be called in global namespace"); \
static ::paddle::lite::jit::JitKernelRegistrar< \
::paddle::lite::jit::ReferKernelPool, \
......@@ -94,84 +94,84 @@ class JitKernelRegistrar {
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
::paddle::lite::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
int LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
return 0; \
}
// kernel_type: should be in paddle::lite::jit::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::lite::fluid
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
"REGISTER_KERNEL_MORE must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
#define REGISTER_KERNEL_MORE_LITE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \
__reg_litejitkernel_##kernel_type##_##impl_type##_##place_type, \
"REGISTER_KERNEL_MORE_LITE must be called in global namespace"); \
extern int LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
UNUSED = LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::lite::jit::JitKernelRegistrar< \
::paddle::lite::jit::KernelPool, \
::paddle::lite::fluid::place_type, \
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::lite::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
int LiteTouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \
return 0; \
}
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::lite::jit::JitKernelRegistrar< \
::paddle::lite::jit::JitCodeCreatorPool, \
::paddle::lite::fluid::CPUPlace, \
__VA_ARGS__> \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
::paddle::lite::jit::KernelType::kernel_type); \
int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
return 0; \
REGISTER_KERNEL_MORE_LITE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
#define REGISTER_GPUKERNEL_MORE_LITE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE_LITE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
#define REGISTER_JITKERNEL_GEN_LITE(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \
__reg_litejitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN_LITE must be called in global namespace"); \
extern int LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::lite::jit::JitKernelRegistrar< \
::paddle::lite::jit::JitCodeCreatorPool, \
::paddle::lite::fluid::CPUPlace, \
__VA_ARGS__> \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
::paddle::lite::jit::KernelType::kernel_type); \
int LiteTouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
return 0; \
}
#define USE_JITKERNEL_GEN(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"USE_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \
static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
#define USE_JITKERNEL_REFER(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace_, \
"USE_JITKERNEL_REFER must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
#define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \
"USE_JITKERNEL_MORE must be called in global namespace"); \
extern int \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
UNUSED = \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
// References the generated-code (gen) registration of `kernel_type` from the
// including translation unit.
// Expansion:
//   1. statically asserts that the macro is used in the global namespace;
//   2. declares LiteTouchJitKernelReg_gen_<kernel_type>_CPUPlace_(), the
//      function defined by REGISTER_JITKERNEL_GEN_LITE(kernel_type, ...);
//   3. defines a file-local static int initialized by calling that function.
// The expansion has no trailing semicolon; the user must supply it.
// NOTE(review): presumably the call exists so the linker keeps the object
// file that holds the registrar -- confirm.
#define USE_JITKERNEL_GEN_LITE(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \
__reg_litejitkernel_gen_##kernel_type##_CPUPlace_, \
"USE_JITKERNEL_GEN_LITE must be called in global namespace"); \
extern int LiteTouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \
static int use_litejitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
LiteTouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
// References the reference (refer) registration of `kernel_type` from the
// including translation unit.
// Expansion:
//   1. statically asserts that the macro is used in the global namespace;
//   2. declares LiteTouchJitKernelReg_<kernel_type>_refer_CPUPlace_(), the
//      function defined by REGISTER_JITKERNEL_REFER_LITE(kernel_type, ...);
//   3. defines a file-local static int initialized by calling that function.
// The expansion has no trailing semicolon; the user must supply it.
// NOTE(review): presumably the call exists so the linker keeps the object
// file that holds the registrar -- confirm.
#define USE_JITKERNEL_REFER_LITE(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \
__reg_litejitkernel_##kernel_type##_refer_CPUPlace_, \
"USE_JITKERNEL_REFER_LITE must be called in global namespace"); \
extern int LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int use_litejitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
// References the "more" registration of `kernel_type` for implementation
// `impl_type` on `place_type` from the including translation unit.
// Expansion:
//   1. statically asserts that the macro is used in the global namespace
//      (the assertion message names the CPU wrapper USE_JITKERNEL_MORE_LITE);
//   2. declares LiteTouchJitKernelReg_<kernel_type>_<impl_type>_<place_type>_(),
//      the function defined by REGISTER_KERNEL_MORE_LITE;
//   3. defines a file-local static int initialized by calling that function.
// The expansion has no trailing semicolon; the user must supply it.
// NOTE(review): presumably the call exists so the linker keeps the object
// file that holds the registrar -- confirm.
#define USE_KERNEL_MORE_LITE(kernel_type, impl_type, place_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \
__reg_litejitkernel_##kernel_type##_##impl_type##_##place_type##_, \
"USE_JITKERNEL_MORE_LITE must be called in global namespace"); \
extern int \
LiteTouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
static int use_litejitkernel_##kernel_type##_##impl_type##_##place_type##_ \
UNUSED = \
LiteTouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
// CPU convenience wrapper: forwards to USE_KERNEL_MORE_LITE with place_type
// fixed to CPUPlace.
#define USE_JITKERNEL_MORE_LITE(kernel_type, impl_type) \
USE_KERNEL_MORE_LITE(kernel_type, impl_type, CPUPlace)
} // namespace jit
} // namespace lite
......
......@@ -67,14 +67,14 @@ class Registry {
#define UNUSED __attribute__((unused))
#endif
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
#define REGISTER_SUBGRAPH_BRIDGE(op_type__, target__, cvt_func_name) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \
__reg_subgraph_bridge_##op_type__##_##target__##__, \
"REGISTER_SUBGRAPH_BRIDGE must be called in global namespace only " \
"once!"); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册