Unverified Commit 693e5e65 authored by: T tensor-tang, committed by: GitHub

Merge pull request #14958 from tensor-tang/refine/jit

enhance jit
@@ -16,6 +16,7 @@ add_subdirectory(metrics)
 add_subdirectory(optimizers)
 add_subdirectory(reduce_ops)
 add_subdirectory(sequence_ops)
+add_subdirectory(jit)
 if(WITH_DISTRIBUTE)
   add_subdirectory(distributed)
@@ -64,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
-set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
 if (WITH_GPU)
   set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
......
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <limits>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/math_function.h"
 namespace paddle {
@@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
     Tensor track;
     int* track_value =
         track.mutable_data<int>(emission_dims, platform::CPUPlace());
-    const auto& ker = math::jitkernel::KernelPool::Instance()
-                          .template Get<math::jitkernel::CRFDecodeKernel<T>>(
-                              static_cast<int>(tag_num));
-    ker->Compute(static_cast<int>(seq_len), x, w, alpha_value, track_value);
+    auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
+                        platform::CPUPlace>(tag_num);
+    ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
     T max_score = -std::numeric_limits<T>::max();
     int max_i = 0;
     for (size_t i = 0; i < tag_num; ++i) {
......
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "xbyak/xbyak.h"
 #include "xbyak/xbyak_util.h"
@@ -108,10 +108,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
       constexpr int simd_width = 16;
       int C = c / simd_width;
-      const auto& multiply =
-          math::jitkernel::KernelPool::Instance()
-              .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
+      auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
+                               platform::CPUPlace>(0);
 #pragma omp parallel for collapse(2)
       for (int ni = 0; ni < n; ni++) {
         for (int ci = 0; ci < C; ci++) {
@@ -122,7 +120,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
           auto ptr_z =
               z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
-          multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
+          multiply(ptr_x, ptr_y, ptr_z, h, w);
         }
       }
     }
......
@@ -15,9 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fusion_gru_op.h"
 #include <cstring>  // for memcpy
 #include <string>
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 namespace paddle {
@@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   const int total_T = x_dims[0];  \
   const int D3 = wh_dims[1]
 #define INIT_OTHER_DEFINES                                                     \
   auto* h0 = ctx.Input<Tensor>("H0");                                          \
   auto* wx = ctx.Input<Tensor>("WeightX");                                     \
   auto* bias = ctx.Input<Tensor>("Bias");                                      \
   auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                          \
   bool is_reverse = ctx.Attr<bool>("is_reverse");                              \
   const int M = x_dims[1];                                                     \
   const int D = wh_dims[0];                                                    \
   const int D2 = D * 2;                                                        \
-  const math::jitkernel::gru_attr_t attr(                                      \
-      D, ctx.Attr<std::string>("gate_activation"),                             \
-      ctx.Attr<std::string>("activation"));                                    \
-  math::jitkernel::gru_t one_step;                                             \
-  const auto& ker =                                                            \
-      math::jitkernel::KernelPool::Instance()                                  \
-          .template Get<math::jitkernel::GRUKernel<T>,                         \
-                        const math::jitkernel::gru_attr_t&>(attr);             \
+  const jit::gru_attr_t attr(                                                  \
+      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),         \
+      jit::to_kerneltype(ctx.Attr<std::string>("activation")));                \
+  jit::gru_t one_step;                                                         \
+  auto ComputeH1 =                                                             \
+      jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr);      \
+  auto ComputeHtPart1 =                                                        \
+      jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
+  auto ComputeHtPart2 =                                                        \
+      jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
   const T* x_data = x->data<T>();                                              \
   const T* wx_data = wx->data<T>();                                            \
   const T* wh_data = wh->data<T>();                                            \
   auto place = ctx.GetPlace();                                                 \
   T* xx_data = xx->mutable_data<T>(place)
 void SeqCompute(const framework::ExecutionContext& ctx) const {
@@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     } else {
       one_step.gates = xx_data;
       one_step.ht = hidden_out_data;
-      ker->ComputeH1(&one_step, &attr);
+      ComputeH1(&one_step, &attr);
       prev_hidden_data = hidden_out_data;
       tstart = 1;
       move_step();
@@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = xx_data;
       one_step.ht_1 = prev_hidden_data;
       one_step.ht = hidden_out_data;
-      ker->ComputeHtPart1(&one_step, &attr);
+      ComputeHtPart1(&one_step, &attr);
       // gemm rt * Ws
       blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                 hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                 xx_data + D2, D3);
-      ker->ComputeHtPart2(&one_step, &attr);
+      ComputeHtPart2(&one_step, &attr);
       // save prev
       prev_hidden_data = hidden_out_data;
       move_step();
@@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     for (int i = 0; i < max_bs; ++i) {
       one_step.gates = cur_in_data;
       one_step.ht = cur_out_data;
-      ker->ComputeH1(&one_step, &attr);
+      ComputeH1(&one_step, &attr);
       // add offset
       cur_in_data += D3;
       cur_out_data += D;
@@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = cur_batched_data;
       one_step.ht_1 = cur_prev_hidden_data;
       one_step.ht = cur_out_data;
-      ker->ComputeHtPart1(&one_step, &attr);
+      ComputeHtPart1(&one_step, &attr);
       cur_batched_data += D3;
       cur_prev_hidden_data += D;
@@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = cur_batched_data;
       one_step.ht_1 = cur_prev_hidden_data;
       one_step.ht = cur_out_data;
-      ker->ComputeHtPart2(&one_step, &attr);
+      ComputeHtPart2(&one_step, &attr);
       cur_batched_data += D3;
       cur_prev_hidden_data += D;
       cur_out_data += D;
......
@@ -14,9 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fusion_lstm_op.h"
 #include <string>
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 namespace paddle {
@@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   const int D = wh_dims[0];  \
   const int D4 = wh_dims[1]
 #define INIT_OTHER_DEFINES                                                      \
   const T* x_data = x->data<T>();                                               \
   const T* wx_data = wx->data<T>();                                             \
   const T* wh_data = wh->data<T>();                                             \
   /* diagonal weight*/                                                          \
   const T* wp_data = bias->data<T>() + D4;                                      \
   /* for peephole only*/                                                        \
   T* checked_cell_data = nullptr;                                               \
   auto place = ctx.GetPlace();                                                  \
   if (use_peepholes) {                                                          \
     /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                            \
     auto* checked_cell = ctx.Output<Tensor>("CheckedCell");                     \
     checked_cell_data = checked_cell->mutable_data<T>(place);                   \
   }                                                                             \
-  const math::jitkernel::lstm_attr_t attr(                                      \
-      D, ctx.Attr<std::string>("gate_activation"),                              \
-      ctx.Attr<std::string>("candidate_activation"),                            \
-      ctx.Attr<std::string>("cell_activation"), use_peepholes);                 \
-  math::jitkernel::lstm_t one_step;                                             \
-  one_step.wp = wp_data;                                                        \
-  one_step.checked = checked_cell_data;                                         \
-  const auto& ker =                                                             \
-      math::jitkernel::KernelPool::Instance()                                   \
-          .template Get<math::jitkernel::LSTMKernel<T>,                         \
-                        const math::jitkernel::lstm_attr_t&>(attr)
+  const jit::lstm_attr_t attr(                                                  \
+      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),          \
+      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")),        \
+      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")),             \
+      use_peepholes);                                                           \
+  jit::lstm_t one_step;                                                         \
+  one_step.wp = wp_data;                                                        \
+  one_step.checked = checked_cell_data;                                         \
+  auto ComputeC1H1 =                                                            \
+      jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr);   \
+  auto ComputeCtHt =                                                            \
+      jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
 // Wh GEMM
 #define GEMM_WH_ADDON(bs, prev, out)  \
@@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.gates = xx_data;
       one_step.ct = c_out_data;
       one_step.ht = h_out_data;
-      ker->ComputeC1H1(&one_step, &attr);
+      ComputeC1H1(&one_step, &attr);
       tstart = 1;
       // move one step
       prev_h_data = h_out_data;
@@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.ct_1 = prev_c_data;
       one_step.ct = c_out_data;
       one_step.ht = h_out_data;
-      ker->ComputeCtHt(&one_step, &attr);
+      ComputeCtHt(&one_step, &attr);
       // move one step
       prev_h_data = h_out_data;
       prev_c_data = c_out_data;
@@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.gates = cur_in_data;
       one_step.ct = cur_c_out_data;
       one_step.ht = cur_h_out_data;
-      ker->ComputeC1H1(&one_step, &attr);
+      ComputeC1H1(&one_step, &attr);
       cur_in_data += D4;
       cur_c_out_data += D;
@@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.ct_1 = cur_prev_c_data;
       one_step.ct = cur_c_out_data;
       one_step.ht = cur_h_out_data;
-      ker->ComputeCtHt(&one_step, &attr);
+      ComputeCtHt(&one_step, &attr);
       // move one batch
       cur_in_data += D4;
......
set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
# refer must go first
add_subdirectory(refer)
add_subdirectory(more)
if(WITH_XBYAK)
add_subdirectory(gen)
endif()
cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper)
if(NOT WIN32)
cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper)
endif()
# JIT Kernel
Combine function templates with JIT code generation to produce the kernel functions that are needed.
A kernel here is a smaller unit of computation than an operator kernel, and its focus is performance on different hardware. A kernel may have implementations backed by several third-party libraries, and each implementation has its own `UseMe` function that decides under what conditions it can be used.
The functions implemented here can be very fine grained, such as a vector multiply, or a complex piece of logic such as an LSTM; complex logic can also be composed from the lower-level functions here.
Currently only high-performance computation on CPU is supported.
## Directory Structure
```txt
PaddlePaddle/Paddle/paddle/fluid/
├── ...
└── operators/
    ├── .../
    └── jit/
        ├── ...
        ├── gen/
        │   └── ...
        ├── more/
        │   ├── ...
        │   ├── mkl/
        │   │   └── ...
        │   ├── mkldnn/
        │   │   └── ...
        │   ├── mix/
        │   │   └── ...
        │   ├── intrinsic/
        │   │   └── ...
        │   └── openblas/
        │       └── ...
        └── refer/
            └── ...
```
The definitions of the basic classes live in the root directory, which contains three subdirectories: gen, more, and refer. Each subdirectory holds one or more implementations. Every kernel must have a reference implementation, which serves as the baseline for unit tests; all other implementations are optional.
- gen: code generated just-in-time, which depends on the xbyak library. This implementation cares most about performance.
- refer: the reference implementation. Every kernel must have a reference implementation on CPU, and its main concern is the correctness of the algorithm logic.
- more: additional implementations, which may be based on mkl, mkldnn, intrinsic, openblas, and so on, or may be compositions of existing kernels.
## Dynamic Dispatch
A `jit::Get` method is provided that looks a kernel up by its kernel type. Each implementation declares its own range of applicability, and the kernel function that is needed is chosen dynamically according to that range and the current runtime conditions.
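For illustration, the sketch below obtains a vector-multiply kernel for length `d` and then invokes it like a plain function. `VMulExample` is a hypothetical wrapper; the kernel key, tuple type, and call signature mirror the usages visible elsewhere in this commit (e.g. `benchmark.cc` and `fusion_gru_op.cc`):

```cpp
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"

namespace jit = paddle::operators::jit;

// Ask Get() for the best available kVMul implementation (jitcode, one of the
// `more` implementations, or the refer fallback) for vectors of length d,
// then call the returned function directly.
template <typename T>
void VMulExample(const T* x, const T* y, T* z, int d) {
  auto vmul =
      jit::Get<jit::kVMul, jit::XYZNTuples<T>, paddle::platform::CPUPlace>(d);
  vmul(x, y, z, d);
}
```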
## Testing
- Correctness tests
All implementations are compared against the refer code and must meet the accuracy requirements, for both the float and double data types (a minimal check sketch follows this list).
- Performance tests
The performance of all implementations is compared, together with the implementation returned by the final `jit::Get` method, which must be the fastest under all conditions.
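A minimal correctness-check sketch, under stated assumptions: gtest's `EXPECT_NEAR` is available to the test (the tests are built with `cc_test`), a `RandomVec` helper like the one in `benchmark.cc` exists, and the tolerance is illustrative; the `GetRefer`/`Get` calls follow the pattern used in `benchmark.cc`:

```cpp
// Compare whatever Get() dispatches to against the refer baseline for kVAdd.
template <typename T>
void CheckVAddAgainstRefer(int d) {
  namespace jit = paddle::operators::jit;
  std::vector<T> x(d), y(d), zref(d), ztgt(d);
  RandomVec<T>(d, x.data());
  RandomVec<T>(d, y.data());
  // The refer implementation is the correctness baseline.
  auto refer = jit::GetRefer<jit::kVAdd, jit::XYZNTuples<T>>();
  refer(x.data(), y.data(), zref.data(), d);
  // The dispatched implementation must agree with the baseline.
  auto tgt =
      jit::Get<jit::kVAdd, jit::XYZNTuples<T>, paddle::platform::CPUPlace>(d);
  tgt(x.data(), y.data(), ztgt.data(), d);
  for (int i = 0; i < d; ++i) {
    EXPECT_NEAR(ztgt[i], zref[i], static_cast<T>(1e-5));
  }
}
```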
# How to Add a New Kernel
- Add `your_key` to `KernelType`.
- Implement the reference logic. This implementation is mandatory, must run on CPU, and must not depend on any third-party library. After implementing it, add `USE_JITKERNEL_REFER(your_key)` in `refer/CMakeLists.txt` to make the kernel available (a minimal refer sketch follows this list).
- (optional) Implement more variants under the `more` directory; these may depend on third-party libraries such as mkl, intrinsic, or mkldnn.
- (optional) Implement Xbyak-based code generation under the `gen` directory. The jitcode needs to implement its own `JitCodeCreator` and register it under the same `KernelType` as the refer implementation.
- If necessary, add a new `KernelTuples` (see `XYZNTuples` for reference); a newly added attr type needs a specialization of the `JitCodeKey` method.
- Add unit tests in `test.cc`, covering at least the `float` and `double` data types, and any additional data types that are needed, such as the `int8`-related functions.
- Add the corresponding performance comparison in `benchmark.cc`; for a given kernel, all implementations should be compared, making sure that the implementation obtained from `jit::Get` is always the fastest.
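As a concrete illustration of the reference step, a refer implementation is usually just a scalar loop. The sketch below uses `VAdd` purely as an example of `your_key`'s kernel; the function name and file placement are assumptions rather than part of this commit, and the kernel would be enabled with `USE_JITKERNEL_REFER(kVAdd)` in `refer/CMakeLists.txt`:

```cpp
// A refer kernel must be plain CPU C++ with no third-party dependencies;
// it is the correctness baseline, not the fast path.
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n) {
  for (int i = 0; i < n; ++i) {
    z[i] = x[i] + y[i];
  }
}
```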
# Advantages
- A unified Get method with a simple interface.
- The same logic can have multiple implementations, depending on multiple third-party libraries, without the implementations interfering with each other.
- A clear directory structure, avoiding the poor readability caused by piling many macro definitions into a single file.
- Easy to optimize: an implementation can be tuned for a specific attribute directly, without affecting performance under other attributes.
- Multiple platforms are supported, including Linux, Mac, and Windows; at a minimum, every platform is guaranteed to work, and platform-specific optimizations can be added later. The framework layer can use a unified interface without caring about the underlying implementation.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
DEFINE_int32(burning, 10, "Burning times.");
DEFINE_int32(repeat, 3000, "Repeat times.");
DEFINE_int32(max_size, 1000, "The Max size would be tested.");
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f), unsigned int seed = 100) {
std::mt19937 rng(seed);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
std::vector<int> TestSizes() {
std::vector<int> s;
for (int i = 1; i <= FLAGS_max_size; ++i) {
s.push_back(i);
}
return s;
}
template <typename KernelTuples, typename... Args>
struct BenchFunc {
// return this function avg time
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...);
}
auto start = paddle::platform::PosixInNsec() / 1e-3;
for (int i = 0; i < FLAGS_repeat; ++i) {
tgt(args...);
}
auto end = paddle::platform::PosixInNsec() / 1e-3;
return static_cast<double>(end - start) / FLAGS_repeat;
}
};
namespace jit = paddle::operators::jit;
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
typename... Args>
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
BenchFunc<KernelTuples, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos;
// test refer
auto refer = jit::GetRefer<KT, KernelTuples>();
if (!refer) {
LOG(FATAL) << "Refer can not be empty!";
}
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
std::make_pair(i->ImplType(), benchmark(more, args...)));
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
if (!tgt) {
LOG(FATAL) << "Target can not be empty!";
}
infos.push_back(std::make_pair("Target", benchmark(tgt, args...)));
// print
std::ostringstream loginfos;
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; ";
}
LOG(INFO) << loginfos.str();
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
for (int d : TestSizes()) {
std::vector<T> x(d), y(d), z(d);
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data(), y.data(),
z.data(), d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
std::vector<T> x(d), y(d);
RandomVec<T>(d, x.data());
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data(), y.data(),
d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
for (int d : TestSizes()) {
std::vector<T> x(d), y(d);
RandomVec<T>(d, x.data());
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data(), y.data(), d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
use_peephole);
std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
T* x_data = x.data();
T* checked_data = checked.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr);
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
std::vector<T> x(3 * d), ht_1(d), ht(d);
RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
const T* ht_1_data = ht_1.data();
T* x_data = x.data();
T* ht_data = ht.data();
jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr);
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
// --burning: the burning time before count
// --repeat: the repeat times
// --max_size: the max size would be tested
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
<< " times.";
using T = float;
using PlaceType = paddle::platform::CPUPlace;
// xyzn
BenchXYZNKernel<jit::kVMul, T, PlaceType>();
BenchXYZNKernel<jit::kVAdd, T, PlaceType>();
BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>();
BenchXYZNKernel<jit::kVSub, T, PlaceType>();
// axyn
BenchAXYNKernel<jit::kVScal, T, PlaceType>();
BenchAXYNKernel<jit::kVAddBias, T, PlaceType>();
// xyn
BenchXYNKernel<jit::kVRelu, T, PlaceType>();
BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
BenchXYNKernel<jit::kVExp, T, PlaceType>();
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
BenchXYNKernel<jit::kVTanh, T, PlaceType>();
// lstm and peephole
BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>();
BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>();
// gru functions
BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
}
file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
function(USE_JITKERNEL_GEN TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n")
endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid)
USE_JITKERNEL_GEN(kVTanh)
USE_JITKERNEL_GEN(kLSTMCtHt)
USE_JITKERNEL_GEN(kLSTMC1H1)
USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN(kGRUHtPart2)
USE_JITKERNEL_GEN(kNCHW16CMulNC)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF),
REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2),
REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1),
REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3),
REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5),
REPEAT_8TIMES(EXP_MAX_INPUT),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)};
int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0};
void VActJitCode::genCode() {
int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]);
act<ymm_t>(ymm_dst, ymm_src, type_);
vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
vmovups(xmm_src, ptr[param1 + offset]);
} else if (rest >= 2) {
block = 2;
vmovq(xmm_src, ptr[param1 + offset]);
} else {
block = 1;
vmovss(xmm_src, ptr[param1 + offset]);
}
act<xmm_t>(xmm_dst, xmm_src, type_);
if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param2 + offset], xmm_dst);
} else {
vmovss(ptr[param2 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_ACT_CREATOR(VRelu);
DECLARE_ACT_CREATOR(VIdentity);
DECLARE_ACT_CREATOR(VExp);
DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me
size_t VReluCreator::CodeSize(const int& d) const {
return 96 /* init size */ +
(d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ *
8 /* average bytes for each instruction */;
}
size_t VIdentityCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
}
size_t VExpCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 70 * 8;
}
size_t VSigmoidCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 82 * 8;
}
size_t VTanhCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 84 * 8;
}
#undef DECLARE_ACT_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator);
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License. */
 #pragma once
 #include <string>
-#include "paddle/fluid/operators/math/jit_gen.h"
-#include "paddle/fluid/operators/math/jit_kernel_impl.h"
-#include "paddle/fluid/platform/cpu_info.h"
+#include "glog/logging.h"
+#include "paddle/fluid/operators/jit/gen/jitcode.h"
 namespace paddle {
 namespace operators {
-namespace math {
-namespace jitkernel {
+namespace jit {
 namespace gen {
-using reg64_t = const Xbyak::Reg64;
-using reg32_t = const Xbyak::Reg32;
-using xmm_t = const Xbyak::Xmm;
-using ymm_t = const Xbyak::Ymm;
-using zmm_t = const Xbyak::Zmm;
-using Label = Xbyak::Label;
-typedef enum {
-  mul = 0,
-  add,
-  sub,
-  relu,
-  exp,
-  sigmoid,
-  tanh,
-  identity
-} operand_type;
 extern const float exp_float_consts[];
 extern const int exp_int_0x7f[];
 extern int g_tmp_mem[];
@@ -79,94 +59,15 @@ extern int g_tmp_mem[];
 #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
 #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
-// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
-class VXXJitCode : public JitCode {
- public:
-  const char* name() const override {
-    std::string base = "VXXJitCode";
-    if (scalar_index_ == 1) {
-      base += "_Scalar";
-    } else {
-      base += "_Vec";
-    }
-    if (type_ == operand_type::mul) {
-      base += "_Mul";
-    } else if (type_ == operand_type::add) {
-      base += "_Add";
-    }
-    if (scalar_index_ == 2) {
-      base += "_Scalar";
-    } else {
-      base += "_Vec";
-    }
-    base += (with_relu_ ? "_Relu" : "");
-    return base.c_str();
-  }
-  explicit VXXJitCode(int d, operand_type type, int scalar_index,
-                      bool with_relu, size_t code_size = 256 * 1024,
-                      void* code_ptr = nullptr)
-      : JitCode(code_size, code_ptr),
-        num_(d),
-        type_(type),
-        scalar_index_(scalar_index),
-        with_relu_(with_relu) {}
-  static bool init(int d, int scalar_index = 0);
-  void generate() override;
- private:
-  int num_;
-  operand_type type_;
-  int scalar_index_;
-  bool with_relu_;
-  reg64_t param1{abi_param1};
-  reg64_t param2{abi_param2};
-  reg64_t param3{abi_param3};
-  xmm_t xmm_src1 = xmm_t(0);
-  xmm_t xmm_src2 = xmm_t(1);
-  xmm_t xmm_dst = xmm_t(2);
-  xmm_t xmm_zero = xmm_t(3);
-  ymm_t ymm_src1 = ymm_t(0);
-  ymm_t ymm_src2 = ymm_t(1);
-  ymm_t ymm_dst = ymm_t(2);
-  ymm_t ymm_zero = ymm_t(3);
-};
-class VActJitCode : public JitCode {
+class VActFunc : public JitCode {
  public:
-  const char* name() const override {
-    std::string base = "VActJitCode";
-    switch (type_) {
-      case operand_type::relu:
-        base += "_Relu";
-        break;
-      case operand_type::exp:
-        base += "_Exp";
-        break;
-      case operand_type::sigmoid:
-        base += "_Sigmoid";
-        break;
-      case operand_type::tanh:
-        base += "_Tanh";
-        break;
-      case operand_type::identity:
-        base += "_Identity";
-        break;
-      default:
-        break;
-    }
-    return base.c_str();
-  }
-  explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
-                       void* code_ptr = nullptr)
-      : JitCode(code_size, code_ptr), num_(d), type_(type) {}
-  static bool init(int d, operand_type type);
-  void generate() override;
+  explicit VActFunc(size_t code_size, void* code_ptr)
+      : JitCode(code_size, code_ptr) {}
+  virtual const char* name() const = 0;
+  virtual void genCode() = 0;
 protected:
-  // compute relu with ymm, xmm
+  // compute RELU with ymm, xmm
   template <typename JMM>
   void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) {  // NOLINT
     JMM zero = JMM(zero_idx);
@@ -174,7 +75,7 @@ class VActJitCode : public JitCode {
     vmaxps(dst, src, zero);
   }
-  // compute exp with ymm, xmm
+  // compute EXP with ymm, xmm
   template <typename JMM>
   void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12,  // NOLINT
                int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
@@ -258,7 +159,7 @@ class VActJitCode : public JitCode {
     pop(reg_ptr_global);
   }
-  // compute sigmoid with ymm, xmm
+  // compute SIGMOID with ymm, xmm
   template <typename JMM>
   void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11,  // NOLINT
                    int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
@@ -283,7 +184,7 @@ class VActJitCode : public JitCode {
     pop(reg_ptr_global);
   }
-  // compute tanh with ymm, xmm
+  // compute TANH with ymm, xmm
   template <typename JMM>
   void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11,  // NOLINT
                 int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
@@ -310,223 +211,109 @@ class VActJitCode : public JitCode {
     pop(reg_ptr_global);
   }
+  // compute IDENTITY with ymm, xmm
+  template <typename JMM>
+  void identity_jmm(JMM& dst, JMM& src, int zero_idx) {  // NOLINT
+    JMM zero = JMM(zero_idx);
+    vxorps(zero, zero, zero);
+    vaddps(dst, src, zero);
+    // TODO(TJ): use below
+    // dst.setIdx(src.getIdx());
+  }
   template <typename JMM>
   void act(JMM& dst, JMM& src, operand_type type) {  // NOLINT
     // use 11~15
     switch (type) {
-      case operand_type::relu:
+      case operand_type::RELU:
         relu_jmm<JMM>(dst, src, 15);
         break;
-      case operand_type::exp:
+      case operand_type::EXP:
         exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
         break;
-      case operand_type::sigmoid:
+      case operand_type::SIGMOID:
         sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
         break;
-      case operand_type::tanh:
+      case operand_type::TANH:
         tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
         break;
-      case operand_type::identity:
+      case operand_type::IDENTITY:
+        identity_jmm<JMM>(dst, src, 15);
         break;
       default:
-        // throw error
+        LOG(FATAL) << "Do not support this operand type: " << type;
         break;
     }
   }
- protected:
-  int num_;
-  operand_type type_;
-  reg64_t param1{abi_param1};
-  reg64_t param2{abi_param2};
-  xmm_t xmm_src = xmm_t(0);
-  ymm_t ymm_src = ymm_t(0);
-  xmm_t xmm_dst = xmm_t(1);
-  ymm_t ymm_dst = ymm_t(1);
 };
-class LSTMJitCode : public VActJitCode {
- public:
-  const char* name() const override {
-    std::string base = "LSTMJitCode";
-    if (use_peephole_) {
-      base += "_Peephole";
-    }
-    if (compute_c1h1_) {
-      base += "_C1H1";
-    }
-    auto AddTypeStr = [&](operand_type type) {
-      switch (type) {
-        case operand_type::relu:
-          base += "_Relu";
-          break;
-        case operand_type::exp:
-          base += "_Exp";
-          break;
-        case operand_type::sigmoid:
-          base += "_Sigmoid";
-          break;
-        case operand_type::tanh:
-          base += "_Tanh";
-          break;
-        case operand_type::identity:
-          base += "_Identity";
-          break;
-        default:
-          break;
-      }
-    };
-    AddTypeStr(act_gate_);
-    AddTypeStr(act_cand_);
-    AddTypeStr(act_cell_);
-    return base.c_str();
-  }
-  explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
-                       size_t code_size = 256 * 1024, void* code_ptr = nullptr)
-      : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
-                    code_ptr),
-        compute_c1h1_(compute_c1h1) {
-    auto typeExchange = [](const std::string& type) -> gen::operand_type {
-      if (type == "sigmoid") {
-        return operand_type::sigmoid;
-      } else if (type == "relu") {
-        return operand_type::relu;
-      } else if (type == "tanh") {
-        return operand_type::tanh;
-      } else if (type == "identity" || type == "") {
-        return operand_type::identity;
-      }  // else throw error
-      return operand_type::identity;
-    };
-    num_ = attr.d;
-    use_peephole_ = attr.use_peephole;
-    act_gate_ = typeExchange(attr.act_gate);
-    act_cand_ = typeExchange(attr.act_cand);
-    act_cell_ = typeExchange(attr.act_cell);
-  }
-  static bool init(int d);
-  void generate() override;
- protected:
-  int num_;
-  bool compute_c1h1_;
-  bool use_peephole_;
-  operand_type act_gate_;
-  operand_type act_cand_;
-  operand_type act_cell_;
-  reg64_t param1{abi_param1};
-};
-class GRUJitCode : public VActJitCode {
- public:
-  const char* name() const override {
-    std::string base = "GRUJitCode";
-    if (id_ == 0) {
-      base += "_H1";
-    } else if (id_ == 1) {
-      base += "_HtPart1";
-    } else if (id_ == 2) {
-      base += "_HtPart2";
-    }
-    auto AddTypeStr = [&](operand_type type) {
-      switch (type) {
-        case operand_type::relu:
-          base += "_Relu";
-          break;
-        case operand_type::exp:
-          base += "_Exp";
-          break;
-        case operand_type::sigmoid:
-          base += "_Sigmoid";
-          break;
-        case operand_type::tanh:
-          base += "_Tanh";
-          break;
-        case operand_type::identity:
-          base += "_Identity";
-          break;
-        default:
-          break;
-      }
-    };
-    AddTypeStr(act_gate_);
-    AddTypeStr(act_cand_);
-    return base.c_str();
-  }
-  explicit GRUJitCode(int id, const gru_attr_t& attr,
-                      size_t code_size = 256 * 1024, void* code_ptr = nullptr)
-      : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
-                    code_ptr),
-        id_(id) {
-    auto typeExchange = [](const std::string& type) -> gen::operand_type {
-      if (type == "sigmoid") {
-        return operand_type::sigmoid;
-      } else if (type == "relu") {
-        return operand_type::relu;
-      } else if (type == "tanh") {
-        return operand_type::tanh;
-      } else if (type == "identity" || type == "") {
-        return operand_type::identity;
-      }  // else throw error
-      return operand_type::identity;
-    };
-    num_ = attr.d;
-    act_gate_ = typeExchange(attr.act_gate);
-    act_cand_ = typeExchange(attr.act_cand);
-  }
-  static bool init(int d);
-  void generate() override;
- protected:
-  int id_;
-  int num_;
-  operand_type act_gate_;
-  operand_type act_cand_;
-  reg64_t param1{abi_param1};
-};
-#ifdef PADDLE_WITH_MKLDNN
-struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator {
-  explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
-      : Xbyak::CodeGenerator(code_size) {
-    // RDI is ptr x_input
-    // RSI is ptr y_input
-    // RDX is ptr output
-    // RCX is height
-    // r8 is width
-    push(rbx);
-    xor_(rax, rax);
-    xor_(r10, r10);
-    vmovups(zmm3, ptr[rsi]);
-    L("h_loop");
-    xor_(rbx, rbx);
-    L("w_loop");
-    vmovups(zmm2, ptr[rdi + rax]);
-    vmulps(zmm1, zmm2, zmm3);
-    vmovups(ptr[rdx + rax], zmm1);
-    add(rax, 64);
-    inc(rbx);
-    cmp(r8, rbx);
-    jnz("w_loop");
-    inc(r10);
-    cmp(r10, rcx);
-    jnz("h_loop");
-    pop(rbx);
-    ret();
-  }
-};
-#endif
+class VActJitCode : public VActFunc {
+ public:
+  explicit VActJitCode(int d, operand_type type, size_t code_size,
+                       void* code_ptr = nullptr)
+      : VActFunc(code_size, code_ptr), num_(d), type_(type) {
+    if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
+          type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
+          type_ == operand_type::IDENTITY)) {
+      LOG(FATAL) << "Do not support this operand type: " << type_;
+    }
+    this->genCode();
+  }
+  const char* name() const override {
+    std::string base = "VActJitCode";
+    switch (type_) {
+      case operand_type::RELU:
+        base += "_Relu";
+        break;
+      case operand_type::EXP:
+        base += "_Exp";
+        break;
+      case operand_type::SIGMOID:
+        base += "_Sigmoid";
+        break;
+      case operand_type::TANH:
+        base += "_Tanh";
+        break;
+      case operand_type::IDENTITY:
+        base += "_Identity";
+        break;
+      default:
+        break;
+    }
+    return base.c_str();
+  }
+  void genCode() override;
+ protected:
+  int num_;
+  operand_type type_;
+  reg64_t param1{abi_param1};
+  reg64_t param2{abi_param2};
+  xmm_t xmm_src = xmm_t(0);
+  ymm_t ymm_src = ymm_t(0);
+  xmm_t xmm_dst = xmm_t(1);
+  ymm_t ymm_dst = ymm_t(1);
+};
+#define DECLARE_ACT_JITCODE(name, op_type)                                     \
+  class name##JitCode : public VActJitCode {                                   \
+   public:                                                                     \
+    explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr)  \
+        : VActJitCode(d, op_type, code_size, code_ptr) {}                      \
+  };
+DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
+DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
+DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
+DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
+DECLARE_ACT_JITCODE(VTanh, operand_type::TANH);
+#undef DECLARE_ACT_JITCODE
 }  // namespace gen
-}  // namespace jitkernel
-}  // namespace math
+}  // namespace jit
 }  // namespace operators
 }  // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void VXXJitCode::genCode() {
// do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
if (scalar_index_ == 1) {
vbroadcastss(ymm_src1, ptr[param1]);
} else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]);
}
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::MUL) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
} else if (rest >= 2) {
block = 2;
if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
}
} else {
block = 1;
if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]);
}
}
switch (type_) {
case operand_type::MUL:
vmulps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
if (rest >= 4) {
vmovups(ptr[param3 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param3 + offset], xmm_dst);
} else {
vmovss(ptr[param3 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
void NCHW16CMulNCJitCode::genCode() {
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push(rbx);
xor_(rax, rax);
xor_(r10, r10);
vmovups(zmm3, ptr[rsi]);
L("h_loop");
xor_(rbx, rbx);
L("w_loop");
vmovups(zmm2, ptr[rdi + rax]);
vmulps(zmm1, zmm2, zmm3);
vmovups(ptr[rdx + rax], zmm1);
add(rax, 64);
inc(rbx);
cmp(r8, rbx);
jnz("w_loop");
inc(r10);
cmp(r10, rcx);
jnz("h_loop");
pop(rbx);
ret();
}
class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& d) const override { return 256 * 1024; }
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<NCHW16CMulNCJitCode>(attr, CodeSize(attr));
}
};
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_BLAS_CREATOR(VMul);
DECLARE_BLAS_CREATOR(VAdd);
DECLARE_BLAS_CREATOR(VSub);
DECLARE_BLAS_CREATOR(VAddRelu);
DECLARE_BLAS_CREATOR(VScal);
DECLARE_BLAS_CREATOR(VAddBias);
#undef DECLARE_BLAS_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
virtual const char* name() const {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::MUL) {
base += "_Mul";
} else if (type_ == operand_type::ADD) {
base += "_Add";
}
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
void genCode() override;
private:
int num_;
operand_type type_;
int scalar_index_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
};
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
class name##JitCode : public VXXJitCode { \
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false);
#undef DECLARE_BLAS_JITCODE
// nChw16c = nChw16c .* NC
class NCHW16CMulNCJitCode : public JitCode {
public:
DECLARE_JIT_CODE(NCHW16CMulNCJitCode);
explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}
void genCode() override;
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void GRUJitCode::genCode() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ht = r10;
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
ymm_t ymm_one = ymm_t(0);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret();
}
#define DECLARE_GRU_CREATOR(name) \
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const gru_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_GRU_CREATOR(GRUH1);
DECLARE_GRU_CREATOR(GRUHtPart1);
DECLARE_GRU_CREATOR(GRUHtPart2);
#undef DECLARE_GRU_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class GRUJitCode : public VActFunc {
public:
explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size,
void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::kVSigmoid) {
return operand_type::SIGMOID;
} else if (type == KernelType::kVRelu) {
return operand_type::RELU;
} else if (type == KernelType::kVTanh) {
return operand_type::TANH;
} else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
this->genCode();
}
const char* name() const override {
std::string base = "GRUJitCode";
if (id_ == 0) {
base += "_H1";
} else if (id_ == 1) {
base += "_HtPart1";
} else if (id_ == 2) {
base += "_HtPart2";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
}
void genCode() override;
protected:
int id_;
int num_;
operand_type act_gate_;
operand_type act_cand_;
reg64_t param1{abi_param1};
};
#define DECLARE_GRU_JITCODE(name, id) \
class name##JitCode : public GRUJitCode { \
public: \
explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: GRUJitCode(id, attr, code_size, code_ptr) {} \
};
DECLARE_GRU_JITCODE(GRUH1, 0);
DECLARE_GRU_JITCODE(GRUHtPart1, 1);
DECLARE_GRU_JITCODE(GRUHtPart2, 2);
#undef DECLARE_GRU_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/platform/cpu_info.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// Application Binary Interface
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
abi_param4(Xbyak::Operand::RCX);
constexpr Xbyak::Operand::Code g_abi_regs[] = {
Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};
constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);
using reg64_t = const Xbyak::Reg64;
using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
typedef enum {
MUL = 0,
ADD,
SUB,
RELU,
EXP,
SIGMOID,
TANH,
IDENTITY
} operand_type;
#define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; }
class JitCode : public GenBase, public Xbyak::CodeGenerator {
public:
explicit JitCode(size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator(
(code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override {
const Xbyak::uint8* code = CodeGenerator::getCode();
return code;
}
protected:
Xbyak::Reg64 param1{abi_param1};
const int EVEX_max_8b_offt = 0x200;
const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
virtual void preCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
push(Xbyak::Reg64(g_abi_regs[i]));
}
if (platform::MayIUse(platform::avx512f)) {
mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
}
}
virtual void postCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
}
ret();
}
void L(const char* label) { Xbyak::CodeGenerator::L(label); }
void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
// Enhanced vector extension
Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast = false) {
int scale = 0;
// Learn from https://github.com/intel/mkl-dnn
if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
offt = offt - 2 * EVEX_max_8b_offt;
scale = 1;
} else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
offt = offt - 4 * EVEX_max_8b_offt;
scale = 2;
}
auto re = Xbyak::RegExp() + base + offt;
if (scale) {
re = re + reg_EVEX_max_8b_offt * scale;
}
if (bcast) {
return zword_b[re];
} else {
return zword[re];
}
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void LSTMJitCode::genCode() {
if (use_peephole_) {
preCode();
}
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t ymm_c = ymm_t(0);
ymm_t ymm_i = ymm_t(1);
ymm_t ymm_f = ymm_t(2);
ymm_t ymm_o = ymm_t(3);
ymm_t ymm_ct_1 = ymm_t(4);
ymm_t ymm_wp0 = ymm_t(5);
ymm_t ymm_wp1 = ymm_t(6);
ymm_t ymm_wp2 = ymm_t(7);
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
if (!compute_c1h1_) {
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
}
if (use_peephole_) {
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act<ymm_t>(ymm_c, ymm_c, act_cand_);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if (!compute_c1h1_ && use_peephole_) {
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
vaddps(ymm_i, ymm_i, ymm_wp0);
}
act<ymm_t>(ymm_i, ymm_i, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) {
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if (use_peephole_) {
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
vaddps(ymm_f, ymm_f, ymm_wp1);
}
act<ymm_t>(ymm_f, ymm_f, act_gate_);
// ct
vmulps(ymm_f, ymm_f, ymm_ct_1);
vaddps(ymm_f, ymm_f, ymm_c);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
// act_gate(o) or act_gate(ct * wp2 + o)
if (use_peephole_) {
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
vaddps(ymm_o, ymm_o, ymm_wp2);
}
act<ymm_t>(ymm_o, ymm_o, act_gate_);
// ht
vmulps(ymm_o, ymm_o, ymm_tmp);
// save ct and ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
if (use_peephole_) {
postCode();
} else {
ret();
}
}
#define DECLARE_LSTM_CREATOR(name) \
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const lstm_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_LSTM_CREATOR(LSTMCtHt);
DECLARE_LSTM_CREATOR(LSTMC1H1);
#undef DECLARE_LSTM_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class LSTMJitCode : public VActFunc {
public:
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size, void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr),
num_(attr.d),
compute_c1h1_(compute_c1h1),
use_peephole_(attr.use_peephole) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::kVSigmoid) {
return operand_type::SIGMOID;
} else if (type == KernelType::kVRelu) {
return operand_type::RELU;
} else if (type == KernelType::kVTanh) {
return operand_type::TANH;
} else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
this->genCode();
}
const char* name() const override {
std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
}
if (compute_c1h1_) {
base += "_C1H1";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
void genCode() override;
protected:
int num_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
reg64_t param1{abi_param1};
};
#define DECLARE_LSTM_JITCODE(name, compute_c1h1) \
class name##JitCode : public LSTMJitCode { \
public: \
explicit name##JitCode(const lstm_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {} \
};
DECLARE_LSTM_JITCODE(LSTMCtHt, false);
DECLARE_LSTM_JITCODE(LSTMC1H1, true);
#undef DECLARE_LSTM_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen_base.h"
#include <fstream>
#include <iostream>
#include <sstream>
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace paddle {
namespace operators {
namespace jit {
// Refer kernels do not need UseMe; they are always the last fallback.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
static int counter = 0;
std::ostringstream filename;
filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
counter++;
std::ofstream fout(filename.str(), std::ios::out);
if (fout.is_open()) {
fout.write(reinterpret_cast<const char*>(code), this->getSize());
fout.close();
}
}
}
} // namespace jit
} // namespace operators
} // namespace paddle
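// A minimal sketch of how to inspect the generated code, based on the flag
// defined above (illustrative only):
//
//   FLAGS_dump_jitcode = true;  // or pass --dump_jitcode=true on the command line
//   // every GenBase::getCode<Func>() call then writes the machine code to
//   // paddle_jitcode_<name>.<counter>.bin in the working directory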
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <memory> // for unique_ptr
#include "paddle/fluid/operators/jit/kernel_base.h"
DECLARE_bool(dump_jitcode);
namespace paddle {
namespace operators {
namespace jit {
class GenBase : public Kernel {
public:
virtual ~GenBase() = default;
virtual const char* name() const = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
template <typename Func>
Func getCode() {
const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
}
return reinterpret_cast<Func>(const_cast<unsigned char*>(code));
}
protected:
void dumpCode(const unsigned char* code) const;
};
// A Creator is used to create the jitcode and save it in the pool.
// Every JitCode should have one creator.
class GenCreator {
public:
virtual ~GenCreator() = default;
};
template <typename Attr>
class JitCodeCreator : public GenCreator {
public:
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;
// create this code
virtual std::unique_ptr<GenBase> CreateJitCode(const Attr& attr) const = 0;
};
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h"
#include <algorithm> // tolower
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
#define ONE_CASE(key) \
case key: \
return #key
const char* to_string(KernelType kt) {
switch (kt) {
ONE_CASE(kVMul);
ONE_CASE(kVAdd);
ONE_CASE(kVAddRelu);
ONE_CASE(kVSub);
ONE_CASE(kVScal);
ONE_CASE(kVAddBias);
ONE_CASE(kVRelu);
ONE_CASE(kVIdentity);
ONE_CASE(kVExp);
ONE_CASE(kVSigmoid);
ONE_CASE(kVTanh);
ONE_CASE(kLSTMCtHt);
ONE_CASE(kLSTMC1H1);
ONE_CASE(kGRUH1);
ONE_CASE(kGRUHtPart1);
ONE_CASE(kGRUHtPart2);
ONE_CASE(kCRFDecoding);
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
}
return nullptr;
}
#undef ONE_CASE
KernelType to_kerneltype(const std::string& act) {
std::string lower = act;
std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
if (lower == "relu" || lower == "vrelu") {
return kVRelu;
} else if (lower == "identity" || lower == "videntity" || lower == "") {
return kVIdentity;
} else if (lower == "exp" || lower == "vexp") {
return kVExp;
} else if (lower == "sigmoid" || lower == "vsigmoid") {
return kVSigmoid;
} else if (lower == "tanh" || lower == "vtanh") {
return kVTanh;
}
PADDLE_THROW("Not support type: %s, or forget to add this case", act);
return kNone;
}
} // namespace jit
} // namespace operators
} // namespace paddle
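// Examples of the two conversions above: to_string(kVTanh) yields "kVTanh",
// and to_kerneltype("Tanh") lower-cases its input and returns kVTanh.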
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
  // the creator is not related to attr, so KernelKey can be used as the key
KernelKey kkey(KT, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr;
}
// Refer code is not related to attr; attr is only needed for the cast here.
// Refer is always on CPUPlace.
template <KernelType KT, typename KernelTuples>
inline typename KernelTuples::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace>
typename KernelTuples::func_type Get(
const typename KernelTuples::attr_type& attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
}
}
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KT, KernelTuples>();
}
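// A minimal usage sketch of the dispatch order implemented by Get(): generated
// jitcode first (float on CPUPlace only), then any registered KernelMore
// implementation whose UseMe() passes, and finally the refer kernel. The
// buffer names x_data/y_data/z_data are illustrative only:
//
//   auto vadd = Get<kVAdd, XYZNTuples<float>, platform::CPUPlace>(n);
//   vadd(x_data, y_data, z_data, n);  // element-wise z = x + y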
const char* to_string(KernelType kt);
KernelType to_kerneltype(const std::string& act);
inline std::ostream& operator<<(std::ostream& os, const lstm_attr_t& attr) {
os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
<< "],act_cand[" << to_string(attr.act_cand) << "],act_cell["
<< to_string(attr.act_cell) << "],use_peephole["
<< (attr.use_peephole ? "True" : "False") << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
<< "],act_cand[" << to_string(attr.act_cand) << "]";
return os;
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
namespace jit {
typedef enum {
kNone = 0,
kVMul = 1,
kVAdd = 2,
kVAddRelu,
kVSub,
kVScal,
kVAddBias,
kVRelu,
kVIdentity,
kVExp,
kVSigmoid,
kVTanh,
kLSTMCtHt,
kLSTMC1H1,
kGRUH1,
kGRUHtPart1,
kGRUHtPart2,
kCRFDecoding,
kLayerNorm,
kNCHW16CMulNC,
} KernelType;
template <typename T>
struct XYZNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int);
};
template <typename T>
struct AXYNTuples : public XYZNTuples<T> {};
template <typename T>
struct XYNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int);
};
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
void* ct;
void* ht;
/* weight_peephole and checked data are only used in peephole*/
const void* wp{nullptr}; // W_ic, W_fc, W_oc
void* checked{nullptr}; // size: 2 * d
} lstm_t;
typedef struct {
void* gates; // gates: {x_update, x_reset; x_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d;
KernelType act_gate, act_cand;
rnn_attr_s() = default;
explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
KernelType act_cell;
lstm_attr_s() = default;
explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
KernelType _act_cell, bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
template <typename T>
struct LSTMTuples {
typedef T data_type;
typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};
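// A minimal sketch of how an LSTM step is laid out for these tuples; the
// buffer names are illustrative only. Per the lstm_t comment above, gates
// holds 4 * d values in {c, i, f, o} order, while ct_1, ct and ht each hold d:
//
//   lstm_attr_t attr(d, kVSigmoid, kVTanh, kVTanh, /*use_peephole=*/false);
//   lstm_t step;
//   step.gates = gates_data;  // 4 * d
//   step.ct_1 = ct_1_data;    // d
//   step.ct = ct_data;        // d
//   step.ht = ht_data;        // d
//   LSTMTuples<float>::func_type lstm = /* obtained via jit::Get */;
//   lstm(&step, &attr);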
template <typename T>
struct GRUTuples {
typedef T data_type;
typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
template <typename T>
struct CRFDecodingTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};
template <typename T>
struct LayerNormTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
const float, int);
};
// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int);
};
// Just for adding to kernel pool without template
class Kernel {
public:
Kernel() = default;
virtual ~Kernel() = default;
DISABLE_COPY_AND_ASSIGN(Kernel);
};
template <typename KernelTuples>
class KernelMore : public Kernel {
public:
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0;
protected:
Func func{nullptr};
};
template <typename KernelTuples>
class ReferKernel : public KernelMore<KernelTuples> {
public:
// Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override {
return true;
}
const char* ImplType() const override { return "Refer"; }
};
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
namespace paddle {
namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(const int& d) {
return d;
}
constexpr int act_type_shift = 3;  // support 2^3 act types
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
}
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) +
(static_cast<int>(attr.act_cand) << act_type_shift);
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
struct KernelKey {
struct Hash {
size_t operator()(const KernelKey& key) const {
int place = key.place_.which(); // less than 2^8
int type = static_cast<int>(key.type_) << 8; // less than 2^(32-8)
std::hash<int> hasher;
return hasher(place + type);
}
};
KernelType type_;
platform::Place place_;
KernelKey(KernelType type, platform::Place place)
: type_(type), place_(place) {}
size_t hash_key() const { return Hash()(*this); }
bool operator==(const KernelKey& o) const {
return platform::places_are_same_class(place_, o.place_) &&
type_ == o.type_;
}
bool operator!=(const KernelKey& o) const { return !(*this == o); }
};
// Every JitCode should have a method to get the key from its attribute
template <typename Attr>
size_t JitCodeKey(const Attr& attr);
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace jit {
JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
static JitCodeCreatorPool g_creator_pool;
return g_creator_pool;
}
KernelPool& KernelPool::Instance() {
static KernelPool g_kernel_pool;
return g_kernel_pool;
}
ReferKernelPool& ReferKernelPool::Instance() {
static ReferKernelPool g_refer_kernel_pool;
return g_refer_kernel_pool;
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT>
class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr;
typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap;
public:
JitCodePool() = default;
static JitCodePool& Instance() {
static thread_local JitCodePool<KT> g_jit_codes;
return g_jit_codes;
}
const JitCodeMap& AllKernels() { return codes_; }
bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
void Insert(size_t key, GenBasePtr value) {
codes_.emplace(key, std::move(value));
}
private:
JitCodeMap codes_;
DISABLE_COPY_AND_ASSIGN(JitCodePool);
};
class JitCodeCreatorPool {
typedef std::unique_ptr<const GenCreator> GenCreatorPtr;
typedef std::unordered_map<KernelKey, std::vector<GenCreatorPtr>,
KernelKey::Hash>
GenCreatorPtrMap;
public:
JitCodeCreatorPool() = default;
static JitCodeCreatorPool& Instance();
GenCreatorPtrMap& AllCreators() { return creators_; }
void Insert(const KernelKey& key, GenCreatorPtr value) {
if (creators_.find(key) == creators_.end()) {
creators_.emplace(key, std::vector<GenCreatorPtr>());
}
creators_.at(key).emplace_back(std::move(value));
}
private:
GenCreatorPtrMap creators_;
DISABLE_COPY_AND_ASSIGN(JitCodeCreatorPool);
};
typedef std::unique_ptr<const Kernel> KernelPtr;
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
KernelMap;
class KernelPool {
public:
static KernelPool& Instance();
KernelPool() = default;
KernelMap& AllKernels() { return pool_; }
void Insert(const KernelKey& key, KernelPtr value) {
if (pool_.find(key) == pool_.end()) {
pool_.emplace(key, std::vector<KernelPtr>());
}
pool_.at(key).emplace_back(std::move(value));
}
private:
KernelMap pool_;
DISABLE_COPY_AND_ASSIGN(KernelPool);
};
// Every kernel should have refer code, and the refer code should be used in
// unit tests, so refer kernels have their own independent kernel pool.
class ReferKernelPool {
public:
static ReferKernelPool& Instance();
ReferKernelPool() = default;
KernelMap& AllKernels() { return pool_; }
void Insert(const KernelKey& key, KernelPtr value) {
if (pool_.find(key) == pool_.end()) {
pool_.emplace(key, std::vector<KernelPtr>());
}
pool_.at(key).emplace_back(std::move(value));
}
private:
KernelMap pool_;
DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
};
} // namespace jit
} // namespace operators
} // namespace paddle
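// Note: these pools are populated by the registration macros declared in
// paddle/fluid/operators/jit/registry.h; presumably REGISTER_JITKERNEL_GEN
// inserts a JitCodeCreator into JitCodeCreatorPool and REGISTER_JITKERNEL_MORE
// inserts a KernelMore implementation into KernelPool, which Get() then
// searches in that order before falling back to ReferKernelPool.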
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
namespace paddle {
namespace operators {
namespace jit {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
} // namespace jit
} // namespace operators
} // namespace paddle
function(USE_JITKERNEL_MORE TARGET TYPE)
file(APPEND ${jit_file} "USE_JITKERNEL_MORE(${TARGET} ${TYPE});\n")
endfunction()
if(WITH_MKLML)
add_subdirectory(mkl)
endif()
if(WITH_AVX)
add_subdirectory(intrinsic)
endif()
# mix should be last
add_subdirectory(mix)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} PARENT_SCOPE)
file(GLOB jit_kernel_cc_intrinsic RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_intrinsic SRCS ${jit_kernel_cc_intrinsic} DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE)
# use intrinsic kernels by name and type
USE_JITKERNEL_MORE(kCRFDecoding, intrinsic)
USE_JITKERNEL_MORE(kLayerNorm, intrinsic)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h"
#include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
// Note: intrinsic code is not built at runtime.
// For example, code built with AVX can still only use AVX when run on an
// AVX512 machine.
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num) {
#ifdef __AVX512F__
const int step_size = ZMM_FLOAT_BLOCK;
#else
const int step_size = YMM_FLOAT_BLOCK;
#endif
const int end = tag_num / step_size;
const int rest = tag_num % step_size;
/* Setup the alpha initial value.*/
int i_offset = 0;
int last_offset = rest - step_size;
for (int i = 0; i <= end; ++i) {
#ifdef __AVX512F__
// Declare the variable for the content of weights, input and alpha values.
__m512 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm512_loadu_ps(w + i_offset);
x_content = _mm512_loadu_ps(x + i_offset);
alpha_content = _mm512_add_ps(w_content, x_content);
// Save the alpha value.
    _mm512_storeu_ps(alpha + i_offset, alpha_content);
#else
// AVX or AVX2
// weights, input and alpha values.
__m256 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm256_loadu_ps(w + i_offset);
x_content = _mm256_loadu_ps(x + i_offset);
alpha_content = _mm256_add_ps(w_content, x_content);
_mm256_storeu_ps(alpha + i_offset, alpha_content);
#endif
i_offset += step_size;
if (i == end - 1) {
if (rest > 0) {
i_offset += last_offset;
} else {
break;
}
}
}
// Use the column-major strategy to get the location of maximum score.
int seq_offset = 0;
constexpr int state_trans_base_idx = 2;
for (int k = 1; k < seq_len; ++k) {
int j_offset = 0;
for (int j = 0; j <= end; ++j) {
/* Initialize the variables of maximum score and location.*/
#ifdef __AVX512F__
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max());
__m512i max_j = _mm512_setzero_si512();
#else
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max());
__m256i max_j = _mm256_set1_epi32(0);
#endif
/* Calculate the offset of transition_weights.*/
int trans_offset = state_trans_base_idx * tag_num + j_offset;
for (int i = 0; i < tag_num; ++i) {
/* Initialize the content of alpha variable with related offset.*/
#ifdef __AVX512F__
__m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i));
/* Obtain the content of weights from un-aligned address.*/
__m512 w_content = _mm512_loadu_ps(w + trans_offset);
__m512 score_v = _mm512_add_ps(alpha_content, w_content);
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS);
/* AVX512 instructions.*/
max_j = _mm512_mask_set1_epi32(max_j, mask, i);
/* Update the max_score value.*/
max_score = _mm512_max_ps(max_score, score_v);
#else
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i);
/* Obtain the content of weights from un-aligned address.*/
__m256 w_content = _mm256_loadu_ps(w + trans_offset);
__m256 score_v = _mm256_add_ps(alpha_content, w_content);
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS);
/* According to the mask value, update the index of the max_score.*/
#ifdef __AVX2__
max_j = _mm256_or_si256(
_mm256_andnot_si256((__m256i)mask, max_j),
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i)));
#else
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0);
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1);
__m128i lo_mask =
_mm256_extractf128_si256(*(__m256i*)&mask, 0); // NOLINT
__m128i hi_mask =
_mm256_extractf128_si256(*(__m256i*)&mask, 1); // NOLINT
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j);
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j);
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i));
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i));
lo_max_j = _mm_or_si128(lo_mask, lo_max_j);
hi_max_j = _mm_or_si128(hi_mask, hi_max_j);
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0);
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1);
#endif
/* Update the max_score value.*/
max_score = _mm256_max_ps(max_score, score_v);
#endif
trans_offset += tag_num;
}
/* Update the alpha and track values. */
#ifdef __AVX512F__
    __m512 x_content =
        _mm512_loadu_ps(x + seq_offset + tag_num + j_offset);
    max_score = _mm512_add_ps(max_score, x_content);
    _mm512_storeu_ps(alpha + seq_offset + tag_num + j_offset, max_score);
    _mm512_storeu_si512(
        reinterpret_cast<__m512i*>(track + seq_offset + tag_num + j_offset),
        max_j);
#else
__m256 x_content = _mm256_loadu_ps(x + seq_offset + tag_num + j_offset);
max_score = _mm256_add_ps(max_score, x_content);
_mm256_storeu_ps(alpha + seq_offset + tag_num + j_offset, max_score);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(track + seq_offset + tag_num + j_offset),
max_j);
#endif
/* Calculate the offset of next step*/
j_offset += step_size;
if (j == end - 1) {
if (rest > 0) {
j_offset += last_offset;
} else {
break;
}
}
}
seq_offset += tag_num;
}
}
bool CRFDecodingKernel::UseMe(const int& d) const {
#ifdef __AVX512F__
constexpr int block = ZMM_FLOAT_BLOCK;
#else
constexpr int block = YMM_FLOAT_BLOCK;
#endif
return platform::MayIUse(platform::avx) && d >= block;
}
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace intrinsic = paddle::operators::jit::more::intrinsic;
REGISTER_JITKERNEL_MORE(kCRFDecoding, intrinsic, intrinsic::CRFDecodingKernel);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num);
class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> {
public:
CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe(
const typename CRFDecodingTuples<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/layer_norm.h"
#include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height,
const float epsilon, int right) {
__m256 sum;
__m256 mean_vec, var_vec;
__m128 hi, lo;
__m256 tmp;
size_t offset;
size_t j;
int block = YMM_FLOAT_BLOCK;
const int rest = right % block;
const int end = right - rest;
__m256 reverse_num_vec =
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
__m256 epsilon_vec = _mm256_set1_ps(epsilon);
int rest_mask =
((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff;
__m256i mask_vec = _mm256_set_epi32(
rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0,
rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0,
rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
for (int i = 0; i < height; ++i) {
offset = i * right;
/* get mean */
sum = _mm256_setzero_ps();
for (j = offset; j < end + offset; j += block) {
sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)x + j);
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec); // NOLINT
sum = _mm256_add_ps(sum, tmp);
}
hi = _mm256_extractf128_ps(sum, 1);
lo = _mm256_extractf128_ps(sum, 0);
sum = _mm256_add_ps(
sum, _mm256_insertf128_ps(
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
sum = _mm256_hadd_ps(sum, sum);
sum = _mm256_hadd_ps(sum, sum);
mean_vec = _mm256_mul_ps(sum, reverse_num_vec);
mean[i] = *reinterpret_cast<float*>(&mean_vec);
/* get variance */
sum = _mm256_setzero_ps();
for (j = offset; j < end + offset; j += block) {
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_mul_ps(tmp, tmp);
sum = _mm256_add_ps(sum, tmp);
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_mul_ps(tmp, tmp);
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec); // NOLINT
sum = _mm256_add_ps(sum, tmp);
}
hi = _mm256_extractf128_ps(sum, 1);
lo = _mm256_extractf128_ps(sum, 0);
sum = _mm256_add_ps(
sum, _mm256_insertf128_ps(
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
sum = _mm256_hadd_ps(sum, sum);
sum = _mm256_hadd_ps(sum, sum);
var_vec = _mm256_mul_ps(sum, reverse_num_vec);
var[i] = *reinterpret_cast<float*>(&var_vec);
/* get x_norm and calculate output*/
for (j = offset; j < end + offset; j += block) {
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(tmp,
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(tmp,
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (scale) {
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)out + j);
}
for (j = offset; j < end + offset; j += block) {
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_mul_ps(_mm256_loadu_ps((const float*)out + j),
_mm256_loadu_ps((const float*)scale + j - offset)));
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_mul_ps(tmp,
_mm256_loadu_ps((const float*)scale + j - offset)));
}
}
if (bias) {
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)out + j);
}
for (j = offset; j < end + offset; j += block) {
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_add_ps(_mm256_loadu_ps((const float*)out + j),
_mm256_loadu_ps((const float*)bias + j - offset)));
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j,
_mm256_add_ps(tmp, _mm256_loadu_ps((const float*)bias +
j - offset)));
}
}
}
}
bool LayerNormKernel::UseMe(const int& d) const {
return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
}
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace intrinsic = paddle::operators::jit::more::intrinsic;
REGISTER_JITKERNEL_MORE(kLayerNorm, intrinsic, intrinsic::LayerNormKernel);
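// A scalar sketch of what the vectorized routine above computes for each of
// the height rows of length right (illustrative only, not part of the build):
//
//   for (int i = 0; i < height; ++i) {
//     const float* row = x + i * right;
//     float m = 0.f, v = 0.f;
//     for (int j = 0; j < right; ++j) m += row[j];
//     m /= right;
//     for (int j = 0; j < right; ++j) v += (row[j] - m) * (row[j] - m);
//     v /= right;
//     mean[i] = m;
//     var[i] = v;
//     for (int j = 0; j < right; ++j) {
//       float norm = (row[j] - m) / std::sqrt(v + epsilon);
//       if (scale) norm *= scale[j];
//       if (bias) norm += bias[j];
//       out[i * right + j] = norm;
//     }
//   }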
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height,
const float epsilon, int right);
class LayerNormKernel : public KernelMore<LayerNormTuples<float>> {
public:
LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
file(GLOB jit_kernel_mix_cc RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_mix SRCS ${jit_kernel_mix_cc} DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE)
USE_JITKERNEL_MORE(kVSigmoid, mix)
USE_JITKERNEL_MORE(kVTanh, mix)
USE_JITKERNEL_MORE(kLSTMCtHt, mix)
USE_JITKERNEL_MORE(kLSTMC1H1, mix)
USE_JITKERNEL_MORE(kGRUH1, mix)
USE_JITKERNEL_MORE(kGRUHtPart1, mix)
USE_JITKERNEL_MORE(kGRUHtPart2, mix)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/mix/mix.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mix {
void VSigmoid(const T* x, T* y, int n) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
compute(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
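// VTanh below is composed from existing kernels via the identity
// tanh(x) = 2 * sigmoid(2 * x) - 1.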
void VTanh(const T* x, T* y, int n) {
const T a = 2, b = -1;
auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n);
compute_scal(&a, x, y, n);
compute_sigmoid(y, y, n);
compute_scal(&a, y, y, n);
compute_addbias(&b, y, y, n);
}
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
if (type == kVSigmoid) {
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
} else if (type == kVRelu) {
return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d);
} else if (type == kVTanh) {
return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d);
} else if (type == kVIdentity) {
return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d);
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
const T* wp = reinterpret_cast<const T*>(step->wp);
T* checked = reinterpret_cast<T*>(step->checked);
const int d = attr->d;
const int d2 = d * 2;
const int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2);
auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_gate_d2 = getActFunc(attr->act_gate, d2);
auto act_gate_d3 = getActFunc(attr->act_gate, d3);
auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d);
if (attr->use_peephole) {
vmul_d(wp, ct_1, checked, d);
vmul_d(wp + d, ct_1, checked + d, d);
vadd_d2(checked, gates + d, gates + d, d2);
act_gate_d2(gates + d, gates + d, d2);
} else {
act_gate_d3(gates + d, gates + d, d3);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand_d(gates, gates, d);
vmul_d(gates, gates + d, gates + d, d);
vmul_d(ct_1, gates + d2, gates + d2, d);
vadd_d(gates + d, gates + d2, ct, d);
if (attr->use_peephole) {
// get ogated
vmul_d(wp + d2, ct, gates + d, d);
vadd_d(gates + d, gates + d3, gates + d3, d);
act_gate_d(gates + d3, gates + d3, d);
}
// H_t = act_cell(C_t) * ogated
act_cell_d(ct, gates + d2, d);
vmul_d(gates + d2, gates + d3, ht, d);
}
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d);
/* C_t = igated * cgated*/
act_gate_d(gates + d, gates + d, d);
act_cand_d(gates, gates, d);
vmul_d(gates, gates + d, ct, d);
if (attr->use_peephole) {
// get outgated, put W_oc * C_t on igated
const T* wp = reinterpret_cast<const T*>(step->wp);
vmul_d(wp + d2, ct, gates + d, d);
vadd_d(gates + d, gates + d3, gates + d3, d);
}
/* H_t = act_cell(C_t) * ogated */
act_gate_d(gates + d3, gates + d3, d);
act_cell_d(ct, gates + d2, d);
vmul_d(gates + d2, gates + d3, ht, d);
}
// compute h1 without h0
void GRUH1(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
int d = attr->d;
int d2 = d * 2;
auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d);
vmul_d(gates, gates + d2, ht, d);
}
// compute the first part of GRU: ht = act_gate(r) * ht_1
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
// W: {W_update, W_reset; W_state}
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc(attr->act_gate, attr->d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
act_gate(gates + attr->d, gates + attr->d, attr->d);
vmul_d(ht_1, gates + attr->d, ht, attr->d);
}
// compute the second part of GRU:
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
int d = attr->d;
auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d);
T* y = gates + d * 2;
act_gate(gates, gates, d);
act_cand(y, y, d);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
// TODO(TJ): tuning me
bool VSigmoidKernel::UseMe(const int& d) const { return true; }
bool VTanhKernel::UseMe(const int& d) const { return true; }
bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
} // namespace mix
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace mix = paddle::operators::jit::more::mix;
#define REGISTER_MORE_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MORE_KERNEL(kVTanh, VTanh);
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1);
REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2);
#undef REGISTER_MORE_KERNEL
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mix {
using T = float;
void VSigmoid(const T* x, T* y, int n);
void VTanh(const T* x, T* y, int n);
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr);
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr);
void GRUH1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
#define DECLARE_MORE_KERNEL(name, tuples) \
class name##Kernel : public KernelMore<tuples<T>> { \
public: \
name##Kernel() { this->func = name; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
}
// XYN
DECLARE_MORE_KERNEL(VSigmoid, XYNTuples);
DECLARE_MORE_KERNEL(VTanh, XYNTuples);
DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples);
DECLARE_MORE_KERNEL(GRUH1, GRUTuples);
DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples);
#undef DECLARE_MORE_KERNEL
} // namespace mix
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
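# Note: each key listed here should also be registered in mkl.cc via
# REGISTER_JITKERNEL_MORE(key, mkl, ...). These USE_JITKERNEL_MORE lines are
# presumably collected into the generated jit kernel header (as the refer
# CMakeLists does with ${jit_file}) so the static registrations are linked in.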
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mkl {
template <>
void VMul<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z);
}
template <>
void VMul<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdMul(n, x, y, z);
}
template <>
void VAdd<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsAdd(n, x, y, z);
}
template <>
void VAdd<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
template <>
void VScal<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
refer::VScal<float>(a, x, y, n);
}
}
template <>
void VScal<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
refer::VScal<double>(a, x, y, n);
}
}
template <>
void VExp<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y);
}
template <>
void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VAddKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VScalKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VExpKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool VTanhKernel<float>::UseMe(const int& d) const {
return d > 7;
}
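// These heuristics steer small sizes away from MKL: the float element-wise
// kernels only prefer MKL on AVX512-capable CPUs with d > 512, and the
// exp/sigmoid/tanh kernels once d > 7; for smaller sizes another (jitcode or
// refer) implementation is expected to win the dispatch.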
#define ALWAYS_USE_ME_WITH_DOUBLE(func)                  \
  template <>                                            \
  bool func##Kernel<double>::UseMe(const int& d) const { \
    return true;                                         \
  }
ALWAYS_USE_ME_WITH_DOUBLE(VMul);
ALWAYS_USE_ME_WITH_DOUBLE(VAdd);
ALWAYS_USE_ME_WITH_DOUBLE(VScal);
ALWAYS_USE_ME_WITH_DOUBLE(VExp);
ALWAYS_USE_ME_WITH_DOUBLE(VSigmoid);
ALWAYS_USE_ME_WITH_DOUBLE(VTanh);
#undef ALWAYS_USE_ME_WITH_DOUBLE
} // namespace mkl
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace mkl = paddle::operators::jit::more::mkl;
#define REGISTER_MKL_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
#undef REGISTER_MKL_KERNEL
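// Both the float and double specializations are registered under one key;
// which one Get<> returns is presumably determined by the KernelTuples type
// requested, e.g. (illustrative)
//   jit::Get<jit::kVExp, jit::XYNTuples<double>, platform::CPUPlace>(n)
// would resolve to mkl::VExpKernel<double>, whose UseMe() always returns true.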
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mkl {
template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n);
template <typename T>
void VScal(const T* a, const T* x, T* y, int n);
template <typename T>
void VExp(const T* x, T* y, int n);
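// The two activations below are layered on MKL's VExp:
//   sigmoid(x) = 1 / (1 + exp(-x)), with x first clamped to
//                [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX];
//   tanh(x)    = 2 * sigmoid(2 * x) - 1.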
template <typename T>
void VSigmoid(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
VExp(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
template <typename T>
void VTanh(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoid(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
}
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
// AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples);
// XYN
DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
#undef DECLARE_MKL_KERNEL
} // namespace mkl
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
cc_library(jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_refer PARENT_SCOPE)
function(USE_JITKERNEL_REFER TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_REFER(${TARGET});\n")
endfunction()
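# file(APPEND ...) accumulates one USE_JITKERNEL_REFER(...) line per kernel in
# the generated ${jit_file} header, which forces the refer registrations to be
# linked into any target that includes it.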
# use refer kernel by name
USE_JITKERNEL_REFER(kVMul)
USE_JITKERNEL_REFER(kVAdd)
USE_JITKERNEL_REFER(kVAddRelu)
USE_JITKERNEL_REFER(kVSub)
USE_JITKERNEL_REFER(kVScal)
USE_JITKERNEL_REFER(kVAddBias)
USE_JITKERNEL_REFER(kVRelu)
USE_JITKERNEL_REFER(kVIdentity)
USE_JITKERNEL_REFER(kVExp)
USE_JITKERNEL_REFER(kVSigmoid)
USE_JITKERNEL_REFER(kVTanh)
USE_JITKERNEL_REFER(kLSTMCtHt)
USE_JITKERNEL_REFER(kLSTMC1H1)
USE_JITKERNEL_REFER(kGRUH1)
USE_JITKERNEL_REFER(kGRUHtPart1)
USE_JITKERNEL_REFER(kGRUHtPart2)
USE_JITKERNEL_REFER(kCRFDecoding)
USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
namespace refer = paddle::operators::jit::refer;
#define REGISTER_REFER_KERNEL(key, func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \
refer::func##Kernel<double>)
REGISTER_REFER_KERNEL(kVMul, VMul);
REGISTER_REFER_KERNEL(kVAdd, VAdd);
REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu);
REGISTER_REFER_KERNEL(kVSub, VSub);
REGISTER_REFER_KERNEL(kVScal, VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVExp, VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh);
REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt);
REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1);
REGISTER_REFER_KERNEL(kGRUH1, GRUH1);
REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1);
REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2);
REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding);
REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
#undef REGISTER_REFER_KERNEL
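// Every kernel type needs a refer registration: the REGISTER_KERNEL_MORE and
// REGISTER_JITKERNEL_GEN macros (see registry.h) reference
// TouchJitKernelReg_<type>_refer_CPUPlace_(), so removing an entry here would
// break every other backend registered for that type.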
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
You may obtain a copy of the License at * You may obtain a copy of the License at
*
http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
*
Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
limitations under the License. */ * limitations under the License. */
#pragma once #pragma once
#include <cmath> #include <cmath>
#include <string> #include <limits>
#include "paddle/fluid/operators/math/jit_kernel_impl.h" #include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace jit {
namespace jitkernel {
namespace refer { namespace refer {
/* Refer code only focus on correctness */
// Refer code only focuses on correctness
template <typename T> template <typename T>
void VMul(const T* x, const T* y, T* z, int n) { void VMul(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -47,6 +48,13 @@ void VAddRelu(const T* x, const T* y, T* z, int n) { ...@@ -47,6 +48,13 @@ void VAddRelu(const T* x, const T* y, T* z, int n) {
} }
} }
template <typename T>
void VSub(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] - y[i];
}
}
template <typename T> template <typename T>
void VScal(const T* a, const T* x, T* y, int n) { void VScal(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -69,7 +77,11 @@ void VRelu(const T* x, T* y, int n) { ...@@ -69,7 +77,11 @@ void VRelu(const T* x, T* y, int n) {
} }
template <typename T> template <typename T>
inline void VIdentity(const T* x, T* y, int n) {} inline void VIdentity(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i];
}
}
template <typename T> template <typename T>
void VExp(const T* x, T* y, int n) { void VExp(const T* x, T* y, int n) {
...@@ -102,20 +114,22 @@ void VTanh(const T* x, T* y, int n) { ...@@ -102,20 +114,22 @@ void VTanh(const T* x, T* y, int n) {
} }
template <typename T> template <typename T>
void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
if (type == "sigmoid") { if (type == kVSigmoid) {
return VSigmoid<T>; return VSigmoid<T>;
} else if (type == "relu") { } else if (type == kVRelu) {
return VRelu<T>; return VRelu<T>;
} else if (type == "tanh") { } else if (type == kVTanh) {
return VTanh<T>; return VTanh<T>;
} else if (type == "identity" || type == "") { } else if (type == kVIdentity) {
return VIdentity<T>; return VIdentity<T>;
} }
PADDLE_THROW("Not support type: %s", type); PADDLE_THROW("Not support type: %s", type);
return nullptr; return nullptr;
} }
// TODO(TJ): add refer gemm and combine the LSTM kernels the same way as the GRU kernels
// compute ct and ht // compute ct and ht
template <typename T> template <typename T>
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
...@@ -231,8 +245,134 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { ...@@ -231,8 +245,134 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
} }
} }
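// Reference CRF decoding (Viterbi recursion). w is assumed to hold the start
// weights in row 0, the end weights in row 1 and the tag-transition weights
// from row state_trans_base_idx (= 2) on, with right = tag_num:
//   alpha[0][i] = w[0][i] + x[0][i]
//   alpha[k][i] = max_j(alpha[k-1][j] + w[j+2][i]) + x[k][i]
//   track[k][i] = the argmax j of the expression above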
template <typename T>
void CRFDecoding(const int seq_len, const T* x, const T* w, T* alpha,
int* track, int right) {
constexpr int state_trans_base_idx = 2;
for (int i = 0; i < right; ++i) {
alpha[i] = w[i] + x[i];
}
for (int k = 1; k < seq_len; ++k) {
for (int i = 0; i < right; ++i) {
T max_score = -std::numeric_limits<T>::max();
int max_j = 0;
for (int j = 0; j < right; ++j) {
T score = alpha[(k - 1) * right + j] +
w[(j + state_trans_base_idx) * right + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha[k * right + i] = max_score + x[k * right + i];
track[k * right + i] = max_j;
}
}
}
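// Reference layer normalization over rows of width `right`:
//   out[i][j] = (x[i][j] - mean[i]) / sqrt(var[i] + epsilon)
// followed, when given, by out[i][j] = out[i][j] * scale[j] + bias[j].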
template <typename T>
void LayerNorm(T* x, T* out, T* mean, T* var, const T* scale, const T* bias,
int height, const float epsilon, int right) {
// get mean
for (int i = 0; i < height; i++) {
T sum = 0.0;
int offset = i * right;
for (int j = 0; j < right; j++) {
sum += x[offset + j];
}
mean[i] = sum / right;
}
// get variance
for (int i = 0; i < height; i++) {
T sum = 0.0;
int offset = i * right;
for (int j = 0; j < right; j++) {
sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]);
}
var[i] = sum / right;
}
for (int i = 0; i < height; i++) {
int offset = i * right;
T sqrt_var = std::sqrt(var[i] + (T)epsilon);
for (int j = 0; j < right; j++) {
out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var;
}
}
if (scale) {
for (int i = 0; i < height; i++) {
int offset = i * right;
for (int j = 0; j < right; j++) {
out[offset + j] *= scale[j];
}
}
}
if (bias) {
for (int i = 0; i < height; i++) {
int offset = i * right;
for (int j = 0; j < right; j++) {
out[offset + j] += bias[j];
}
}
}
}
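// Reference NCHW16C multiply: x is walked in blocks of ZMM_FLOAT_BLOCK (16
// floats, one AVX-512 zmm register) across height * width positions, and every
// block is scaled element-wise by the 16 per-channel values in y.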
template <typename T>
void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
int offset = 0;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int i = 0; i < 16; ++i) {
z[i + offset] = y[i] * x[i + offset];
}
offset += ZMM_FLOAT_BLOCK;
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
}
// const T* x, const T* y, T* z, int n
DECLARE_REFER_KERNEL(VMul, XYZNTuples);
DECLARE_REFER_KERNEL(VAdd, XYZNTuples);
DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples);
DECLARE_REFER_KERNEL(VSub, XYZNTuples);
// const T* a, const T* x, T* y, int n
DECLARE_REFER_KERNEL(VScal, AXYNTuples);
DECLARE_REFER_KERNEL(VAddBias, AXYNTuples);
// const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu, XYNTuples);
DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples);
// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
// gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples);
DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer } // namespace refer
} // namespace jitkernel } // namespace jit
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory>
#include <tuple>
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" // for UNUSED
namespace paddle {
namespace operators {
namespace jit {
// std::make_unique is only available since C++14, so provide a local equivalent
template <typename T, typename... Args>
inline std::unique_ptr<T> make_unique(Args&&... args) {
static_assert(!std::is_array<T>::value, "T must not be array");
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
template <typename Pool, typename PlaceType, bool IsEnd, size_t I,
typename... KernelImpls>
struct JitKernelRegistrarFunctor;
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, true, I, KernelImpls...> {
void operator()(KernelType kt) const {}
};
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
using KERNEL_IMPL_TYPE =
typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;
void operator()(KernelType kt) const {
KernelKey kkey(kt, PlaceType());
Pool().Instance().Insert(kkey,
std::move(make_unique<const KERNEL_IMPL_TYPE>()));
constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
KernelImpls...>
func;
func(kt);
}
};
template <typename Pool, typename PlaceType, typename... KernelImpls>
class JitKernelRegistrar {
public:
explicit JitKernelRegistrar(KernelType kt) {
JitKernelRegistrarFunctor<Pool, PlaceType, false, 0, KernelImpls...> func;
func(kt);
}
void Touch() {}
};
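// The registrar expands the KernelImpls parameter pack at compile time and
// inserts one instance of each implementation into Pool under the key
// (KernelType, PlaceType). Touch() does nothing; it only gives the USE_*
// macros a symbol to reference so the registration object is not discarded by
// the linker.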
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Refer kernels are always registered on CPUPlace
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
"REGISTER_KERNEL_REFER must be called in global namespace"); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::ReferKernelPool, ::paddle::platform::CPUPlace, \
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
return 0; \
}
// kernel_type: should be in paddle::operators::jit::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
"REGISTER_KERNEL_MORE must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::KernelPool, ::paddle::platform::place_type, \
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \
return 0; \
}
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::JitCodeCreatorPool, \
::paddle::platform::CPUPlace, __VA_ARGS__> \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
return 0; \
}
#define USE_JITKERNEL_GEN(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"USE_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \
static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
#define USE_JITKERNEL_REFER(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace_, \
"USE_JITKERNEL_REFER must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
#define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \
"USE_JITKERNEL_MORE must be called in global namespace"); \
extern int \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
UNUSED = \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
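// Illustrative use of these macros, mirroring refer.cc and mkl.cc above:
//   REGISTER_JITKERNEL_REFER(kVMul, refer::VMulKernel<float>,
//                            refer::VMulKernel<double>);
//   REGISTER_JITKERNEL_MORE(kVMul, mkl, mkl::VMulKernel<float>,
//                           mkl::VMulKernel<double>);
//   USE_JITKERNEL_REFER(kVMul);  // pulled in through the generated jit header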
} // namespace jit
} // namespace operators
} // namespace paddle