未验证 提交 3ae6692a 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #14512 from tensor-tang/fea/jit/rnn

Fea/jit/rnn
...@@ -183,24 +183,27 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -183,24 +183,27 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \ const int total_T = x_dims[0]; \
const int D3 = wh_dims[1] const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \ auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \ auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \ bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \ const int M = x_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D2 = D * 2; \ const int D2 = D * 2; \
const auto& ker = math::jitkernel::KernelPool::Instance() \ const math::jitkernel::gru_attr_t attr( \
.template Get<math::jitkernel::GRUKernel<T>, \ D, ctx.Attr<std::string>("gate_activation"), \
const std::string&, const std::string&>( \ ctx.Attr<std::string>("activation")); \
ctx.Attr<std::string>("gate_activation"), \ math::jitkernel::gru_t one_step; \
ctx.Attr<std::string>("activation"), D); \ const auto& ker = \
const T* x_data = x->data<T>(); \ math::jitkernel::KernelPool::Instance() \
const T* wx_data = wx->data<T>(); \ .template Get<math::jitkernel::GRUKernel<T>, \
const T* wh_data = wh->data<T>(); \ const math::jitkernel::gru_attr_t&>(attr); \
auto place = ctx.GetPlace(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place) T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
...@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) { if (h0_data) {
prev_hidden_data = h0_data + bid * D; prev_hidden_data = h0_data + bid * D;
} else { } else {
ker->ComputeH1(xx_data, hidden_out_data); one_step.gates = xx_data;
one_step.ht = hidden_out_data;
ker->ComputeH1(&one_step, &attr);
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data, prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3); D3);
ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data); one_step.gates = xx_data;
one_step.ht_1 = prev_hidden_data;
one_step.ht = hidden_out_data;
ker->ComputeHtPart1(&one_step, &attr);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data); ker->ComputeHtPart2(&one_step, &attr);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
ker->ComputeH1(cur_in_data, cur_out_data); one_step.gates = cur_in_data;
one_step.ht = cur_out_data;
ker->ComputeH1(&one_step, &attr);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data, one_step.gates = cur_batched_data;
cur_out_data); one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data;
ker->ComputeHtPart1(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data, one_step.gates = cur_batched_data;
cur_out_data); one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data;
ker->ComputeHtPart2(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
......
...@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D4 = wh_dims[1] const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \ /* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \ const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \ /* for peephole only*/ \
T* checked_cell_data = nullptr; \ T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \ auto place = ctx.GetPlace(); \
if (use_peepholes) { \ if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} \ } \
const auto& ker = \ const math::jitkernel::lstm_attr_t attr( \
math::jitkernel::KernelPool::Instance() \ D, ctx.Attr<std::string>("gate_activation"), \
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \ ctx.Attr<std::string>("candidate_activation"), \
const std::string&, const std::string&>( \ ctx.Attr<std::string>("cell_activation"), use_peepholes); \
ctx.Attr<std::string>("gate_activation"), \ math::jitkernel::lstm_t one_step; \
ctx.Attr<std::string>("candidate_activation"), \ one_step.wp = wp_data; \
ctx.Attr<std::string>("cell_activation"), D, use_peepholes) one_step.checked = checked_cell_data; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, \
const math::jitkernel::lstm_attr_t&>(attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
...@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_h_data = h0_data + bid * D; prev_h_data = h0_data + bid * D;
prev_c_data = c0_data + bid * D; prev_c_data = c0_data + bid * D;
} else { } else {
ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data); one_step.gates = xx_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeC1H1(&one_step, &attr);
tstart = 1; tstart = 1;
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
...@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
} }
for (int step = tstart; step < seq_len; ++step) { for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data); GEMM_WH_ADDON(1, prev_h_data, xx_data);
ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data,
checked_cell_data); one_step.gates = xx_data;
one_step.ct_1 = prev_c_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeCtHt(&one_step, &attr);
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
prev_c_data = c_out_data; prev_c_data = c_out_data;
...@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_h_out_data = batched_h_out_data; T* cur_h_out_data = batched_h_out_data;
T* cur_c_out_data = batched_c_out_data; T* cur_c_out_data = batched_c_out_data;
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data); one_step.gates = cur_in_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeC1H1(&one_step, &attr);
cur_in_data += D4; cur_in_data += D4;
cur_c_out_data += D; cur_c_out_data += D;
cur_h_out_data += D; cur_h_out_data += D;
...@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_c_out_data = batched_c_out_data; T* cur_c_out_data = batched_c_out_data;
T* cur_h_out_data = batched_h_out_data; T* cur_h_out_data = batched_h_out_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, one_step.gates = cur_in_data;
cur_h_out_data, wp_data, checked_cell_data); one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeCtHt(&one_step, &attr);
// move one batch // move one batch
cur_in_data += D4; cur_in_data += D4;
cur_prev_c_data += D; cur_prev_c_data += D;
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h" #include "paddle/fluid/operators/math/jit_code.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me #include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
namespace paddle { namespace paddle {
...@@ -139,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) { ...@@ -139,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
} }
void VActJitCode::generate() { void VActJitCode::generate() {
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_zero = ymm_t(2);
if (type_ == operand_type::relu) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
int offset = 0; int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]); vmovups(ymm_src, ptr[param1 + offset]);
switch (type_) { act<ymm_t>(ymm_dst, ymm_src, type_);
case operand_type::relu:
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
break;
case operand_type::exp:
exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::identity:
break;
default:
break;
}
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
...@@ -181,22 +160,7 @@ void VActJitCode::generate() { ...@@ -181,22 +160,7 @@ void VActJitCode::generate() {
block = 1; block = 1;
vmovss(xmm_src, ptr[param1 + offset]); vmovss(xmm_src, ptr[param1 + offset]);
} }
switch (type_) { act<xmm_t>(xmm_dst, xmm_src, type_);
case operand_type::relu:
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
}
if (rest >= 4) { if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst); vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) { } else if (rest >= 2) {
...@@ -210,6 +174,158 @@ void VActJitCode::generate() { ...@@ -210,6 +174,158 @@ void VActJitCode::generate() {
ret(); ret();
} }
bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void LSTMJitCode::generate() {
if (use_peephole_) {
preCode();
}
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t ymm_c = ymm_t(0);
ymm_t ymm_i = ymm_t(1);
ymm_t ymm_f = ymm_t(2);
ymm_t ymm_o = ymm_t(3);
ymm_t ymm_ct_1 = ymm_t(4);
ymm_t ymm_wp0 = ymm_t(5);
ymm_t ymm_wp1 = ymm_t(6);
ymm_t ymm_wp2 = ymm_t(7);
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
if (!compute_c1h1_) {
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
}
if (use_peephole_) {
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act<ymm_t>(ymm_c, ymm_c, act_cand_);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if (!compute_c1h1_ && use_peephole_) {
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
vaddps(ymm_i, ymm_i, ymm_wp0);
}
act<ymm_t>(ymm_i, ymm_i, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) {
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if (use_peephole_) {
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
vaddps(ymm_f, ymm_f, ymm_wp1);
}
act<ymm_t>(ymm_f, ymm_f, act_gate_);
// ct
vmulps(ymm_f, ymm_f, ymm_ct_1);
vaddps(ymm_f, ymm_f, ymm_c);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
// act_gate(o) or act_gate(ct * wp2 + o)
if (use_peephole_) {
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
vaddps(ymm_o, ymm_o, ymm_wp2);
}
act<ymm_t>(ymm_o, ymm_o, act_gate_);
// ht
vmulps(ymm_o, ymm_o, ymm_tmp);
// save ct and ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
if (use_peephole_) {
postCode();
} else {
ret();
}
}
bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void GRUJitCode::generate() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ht = r10;
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
ymm_t ymm_one = ymm_t(0);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret();
}
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_gen.h" #include "paddle/fluid/operators/math/jit_gen.h"
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
...@@ -46,14 +47,6 @@ extern const float exp_float_consts[]; ...@@ -46,14 +47,6 @@ extern const float exp_float_consts[];
extern const int exp_int_0x7f[]; extern const int exp_int_0x7f[];
extern int g_tmp_mem[]; extern int g_tmp_mem[];
// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
#define ALIGN32 __attribute__((aligned(32))) #define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f #define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f #define EXP_LOW -88.3762626647949f
...@@ -176,31 +169,34 @@ class VActJitCode : public JitCode { ...@@ -176,31 +169,34 @@ class VActJitCode : public JitCode {
protected: protected:
// compute relu with ymm, xmm // compute relu with ymm, xmm
template <typename JMM> template <typename JMM>
void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT
JMM zero = JMM(zero_idx);
vxorps(zero, zero, zero);
vmaxps(dst, src, zero); vmaxps(dst, src, zero);
} }
// compute exp with ymm, xmm // compute exp with ymm, xmm
template <typename JMM> template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int mask_idx = 4, int tmp_idx = 5) { int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
using namespace platform::jit; // NOLINT using namespace platform::jit; // NOLINT
assert(src.getIdx() != dst.getIdx()); // TODO(TJ): use enfore
// check all idx can not equal // check all idx can not equal
JMM jmm_src = JMM(src_idx);
JMM jmm_fx = JMM(fx_idx); JMM jmm_fx = JMM(fx_idx);
JMM jmm_fy = JMM(fy_idx); JMM jmm_fy = JMM(fy_idx);
JMM jmm_mask = JMM(mask_idx); JMM jmm_mask = JMM(mask_idx);
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(src, src, jmm_tmp); vminps(jmm_src, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
vmaxps(src, src, jmm_tmp); vmaxps(jmm_src, jmm_src, jmm_tmp);
// express exp(x) as exp(g + n*log(2)) // express exp(x) as exp(g + n*log(2))
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
vmulps(jmm_fx, src, jmm_tmp); vmulps(jmm_fx, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
vaddps(jmm_fx, jmm_fx, jmm_tmp); vaddps(jmm_fx, jmm_fx, jmm_tmp);
vroundps(jmm_fy, jmm_fx, 0x01); vroundps(jmm_fy, jmm_fx, 0x01);
...@@ -214,21 +210,21 @@ class VActJitCode : public JitCode { ...@@ -214,21 +210,21 @@ class VActJitCode : public JitCode {
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
JMM ymm_z = JMM(jmm_mask.getIdx()); JMM ymm_z = JMM(jmm_mask.getIdx());
vmulps(ymm_z, jmm_fx, jmm_tmp); vmulps(ymm_z, jmm_fx, jmm_tmp);
vsubps(src, src, jmm_fy); vsubps(jmm_src, jmm_src, jmm_fy);
vsubps(src, src, ymm_z); vsubps(jmm_src, jmm_src, ymm_z);
vmulps(ymm_z, src, src); vmulps(ymm_z, jmm_src, jmm_src);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(dst, src, jmm_tmp); vmulps(dst, jmm_src, jmm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (YMM_FLOAT_BLOCK * sizeof(float))) { i += (YMM_FLOAT_BLOCK * sizeof(float))) {
vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4 vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmulps(dst, dst, src); vmulps(dst, dst, jmm_src);
} }
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmulps(dst, dst, ymm_z); vmulps(dst, dst, ymm_z);
vaddps(dst, dst, src); vaddps(dst, dst, jmm_src);
vmovaps(jmm_tmp, ptr[reg_ptr_global]); vmovaps(jmm_tmp, ptr[reg_ptr_global]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
// build 2^n // build 2^n
...@@ -265,20 +261,23 @@ class VActJitCode : public JitCode { ...@@ -265,20 +261,23 @@ class VActJitCode : public JitCode {
// compute sigmoid with ymm, xmm // compute sigmoid with ymm, xmm
template <typename JMM> template <typename JMM>
void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2, // NOLINT void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) { int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
int tmp_idx = 15) {
// y = 1 / (1 + e^-x) // y = 1 / (1 + e^-x)
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
JMM jmm_src = JMM(src_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
vminps(src, src, jmm_tmp); vminps(jmm_src, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
vmaxps(src, src, jmm_tmp); vmaxps(jmm_src, jmm_src, jmm_tmp);
vxorps(jmm_tmp, jmm_tmp, jmm_tmp); vxorps(jmm_tmp, jmm_tmp, jmm_tmp);
vsubps(src, jmm_tmp, src); vsubps(jmm_src, jmm_tmp, jmm_src);
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); exp_jmm<JMM>(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vdivps(dst, jmm_tmp, dst); vdivps(dst, jmm_tmp, dst);
...@@ -287,19 +286,22 @@ class VActJitCode : public JitCode { ...@@ -287,19 +286,22 @@ class VActJitCode : public JitCode {
// compute tanh with ymm, xmm // compute tanh with ymm, xmm
template <typename JMM> template <typename JMM>
void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int mask_idx = 4, int tmp_idx = 5) { int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
int tmp_idx = 15) {
// y = 2 / (1 + e^(-2x)) - 1 // y = 2 / (1 + e^(-2x)) - 1
JMM jmm_src = JMM(src_idx);
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
JMM jmm_zero = JMM(mask_idx); JMM jmm_zero = JMM(mask_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(jmm_zero, jmm_zero, jmm_zero); vxorps(jmm_zero, jmm_zero, jmm_zero);
vsubps(jmm_tmp, jmm_zero, jmm_tmp); vsubps(jmm_tmp, jmm_zero, jmm_tmp);
vmulps(src, src, jmm_tmp); vmulps(jmm_src, jmm_src, jmm_tmp);
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); exp_jmm<JMM>(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
...@@ -309,6 +311,30 @@ class VActJitCode : public JitCode { ...@@ -309,6 +311,30 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
template <typename JMM>
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
// use 11~15
switch (type) {
case operand_type::relu:
relu_jmm<JMM>(dst, src, 15);
break;
case operand_type::exp:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::sigmoid:
sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::tanh:
tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::identity:
break;
default:
// throw error
break;
}
}
protected: protected:
int num_; int num_;
operand_type type_; operand_type type_;
...@@ -322,6 +348,148 @@ class VActJitCode : public JitCode { ...@@ -322,6 +348,148 @@ class VActJitCode : public JitCode {
ymm_t ymm_dst = ymm_t(1); ymm_t ymm_dst = ymm_t(1);
}; };
class LSTMJitCode : public VActJitCode {
public:
const char* name() const override {
std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
}
if (compute_c1h1_) {
base += "_C1H1";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
compute_c1h1_(compute_c1h1) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
use_peephole_ = attr.use_peephole;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
}
static bool init(int d);
void generate() override;
protected:
int num_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
reg64_t param1{abi_param1};
};
class GRUJitCode : public VActJitCode {
public:
const char* name() const override {
std::string base = "GRUJitCode";
if (id_ == 0) {
base += "_H1";
} else if (id_ == 1) {
base += "_HtPart1";
} else if (id_ == 2) {
base += "_HtPart2";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
}
explicit GRUJitCode(int id, const gru_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
id_(id) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
}
static bool init(int d);
void generate() override;
protected:
int id_;
int num_;
operand_type act_gate_;
operand_type act_cand_;
reg64_t param1{abi_param1};
};
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator {
explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024) explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory> // for shared_ptr #include <memory> // for shared_ptr
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -26,14 +27,7 @@ namespace operators { ...@@ -26,14 +27,7 @@ namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
// TODO(TJ): move these to some proper place // TODO(TJ): remove me
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
class Kernel { class Kernel {
...@@ -128,24 +122,18 @@ class VTanhKernel : public VActKernel<T> {}; ...@@ -128,24 +122,18 @@ class VTanhKernel : public VActKernel<T> {};
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class LSTMKernel : public Kernel {
public: public:
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht,
/* below only used in peephole*/
const T *wp_data = nullptr,
T *checked = nullptr) const = 0;
// compute c1 and h1 without c0 or h0 // compute c1 and h1 without c0 or h0
virtual void ComputeC1H1(T *gates, T *ct, T *ht, void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
/* below only used in peephole*/ void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
const T *wp_data = nullptr) const = 0;
}; };
template <typename T> template <typename T>
class GRUKernel : public Kernel { class GRUKernel : public Kernel {
public: public:
// compute h1 without h0 // compute h1 without h0
virtual void ComputeH1(T *gates, T *ht) const = 0; void (*ComputeH1)(gru_t *, const gru_attr_t *);
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0; void (*ComputeHtPart1)(gru_t *, const gru_attr_t *);
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0; void (*ComputeHtPart2)(gru_t *, const gru_attr_t *);
}; };
template <typename T> template <typename T>
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -31,49 +32,6 @@ namespace math { ...@@ -31,49 +32,6 @@ namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
template <typename T>
void VMulRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
template <typename T>
void VAddRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
template <typename T>
void VAddReluRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
template <typename T>
void VScalRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] * x[i];
}
}
template <typename T>
void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] + x[i];
}
}
template <typename T>
void VReluRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
template <typename T> template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n); void VMulMKL(const T* x, const T* y, T* z, int n);
...@@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) { ...@@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if (x == y) { if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1); platform::dynload::cblas_sscal(n, *a, y, 1);
} else { } else {
VScalRefer<float>(a, x, y, n); refer::VScal<float>(a, x, y, n);
} }
} }
...@@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) { ...@@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if (x == y) { if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1); platform::dynload::cblas_dscal(n, *a, y, 1);
} else { } else {
VScalRefer<double>(a, x, y, n); refer::VScal<double>(a, x, y, n);
} }
} }
...@@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VMulRefer<T>; this->Compute = refer::VMul<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> { ...@@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VAddRefer<T>; this->Compute = refer::VAdd<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -280,7 +238,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> { ...@@ -280,7 +238,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VAddReluRefer<T>; this->Compute = refer::VAddRelu<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -318,7 +276,7 @@ class VScalKernelImpl : public VScalKernel<T> { ...@@ -318,7 +276,7 @@ class VScalKernelImpl : public VScalKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VScalRefer<T>; this->Compute = refer::VScal<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -362,7 +320,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> { ...@@ -362,7 +320,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
} }
#endif #endif
this->Compute = VAddBiasRefer<T>; this->Compute = refer::VAddBias<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -396,7 +354,7 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -396,7 +354,7 @@ class VReluKernelImpl : public VReluKernel<T> {
} }
#endif #endif
this->Compute = VReluRefer<T>; this->Compute = refer::VRelu<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -412,16 +370,13 @@ bool VReluKernelImpl<float>::useJIT(int d) { ...@@ -412,16 +370,13 @@ bool VReluKernelImpl<float>::useJIT(int d) {
} }
#endif #endif
template <typename T>
inline void VIdentityRefer(const T* x, T* y, int n) {}
/* An empty JitKernel */ /* An empty JitKernel */
template <typename T> template <typename T>
class VIdentityKernelImpl : public VIdentityKernel<T> { class VIdentityKernelImpl : public VIdentityKernel<T> {
public: public:
JITKERNEL_DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
this->Compute = VIdentityRefer<T>; this->Compute = refer::VIdentity<T>;
} }
}; };
......
...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath> // for exp
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h" #include "paddle/fluid/operators/math/jit_code.h"
...@@ -25,48 +25,12 @@ limitations under the License. */ ...@@ -25,48 +25,12 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
// TODO(TJ): move refer codes to one file
// Refer code only focus on correctness
template <typename T>
void VExpRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
template <typename T>
void VSigmoidRefer(const T* x, T* y, int n) {
// y = 1 / (1 + e^-x)
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
}
template <typename T>
void VTanhRefer(const T* x, T* y, int n) {
// y = 2 * sigmoid(2x) - 1
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoidRefer(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup // try to use MKL to speedup
template <typename T> template <typename T>
...@@ -129,7 +93,7 @@ class VExpKernelImpl : public VExpKernel<T> { ...@@ -129,7 +93,7 @@ class VExpKernelImpl : public VExpKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VExpRefer<T>; this->Compute = refer::VExp<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -182,7 +146,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -182,7 +146,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VSigmoidRefer<T>; this->Compute = refer::VSigmoid<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -234,7 +198,7 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -234,7 +198,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VTanhRefer<T>; this->Compute = refer::VTanh<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -267,154 +231,6 @@ REGISTER_JITKERNEL(vexp, VExpKernel); ...@@ -267,154 +231,6 @@ REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel); REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
REGISTER_JITKERNEL(vtanh, VTanhKernel); REGISTER_JITKERNEL(vtanh, VTanhKernel);
namespace detail {
#ifdef __AVX__
#define ALIGN32 __attribute__((aligned(32)))
#define _PS256_CONST(Name, Val) \
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
#define _PI256_CONST(Name, Val) \
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
_PI256_CONST(0x7f, 0x7f);
_PS256_CONST(one, 1.f);
_PS256_CONST(0p5, 0.5f);
_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
typedef union imm_xmm_union {
__m256i imm;
__m128i xmm[2];
} imm_xmm_union;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \
imm_xmm_union u ALIGN32; \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \
imm_xmm_union u ALIGN32; \
u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \
imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
/* use SSE2 to perform the bitop AVX2 */ \
__m128i x1, x2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1, y); \
x2 = _mm_##fn(x2, y); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
#define AVX2_INTOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
/* use SSE2 to perform the AVX2 integer operation */ \
__m128i x1, x2; \
__m128i y1, y2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1, y1); \
x2 = _mm_##fn(x2, y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
AVX2_BITOP_USING_SSE2(slli_epi32);
AVX2_INTOP_USING_SSE2(add_epi32);
#define AVXEXP_BASE \
__m256 tmp = _mm256_setzero_ps(), fx; \
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one); \
__m256i imm0; \
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); \
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo)); \
/* express exp(x) as exp(g + n*log(2)) */ \
fx = _mm256_mul_ps(x, \
*reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5)); \
tmp = _mm256_floor_ps(fx); \
/* if greater, substract 1 */ \
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
mask = _mm256_and_ps(mask, one); \
fx = _mm256_sub_ps(tmp, mask); \
tmp = _mm256_mul_ps(fx, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
__m256 z = _mm256_mul_ps( \
fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2)); \
x = _mm256_sub_ps(x, tmp); \
x = _mm256_sub_ps(x, z); \
z = _mm256_mul_ps(x, x); \
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5)); \
y = _mm256_mul_ps(y, z); \
y = _mm256_add_ps(y, x); \
y = _mm256_add_ps(y, one); \
/* build 2^n */ \
imm0 = _mm256_cvttps_epi32(fx)
__m256 ExpAVX(__m256 x) {
AVXEXP_BASE;
// two AVX2 instructions using SSE2
imm0 = avx2_mm256_add_epi32(imm0,
*reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = avx2_mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x) {
AVXEXP_BASE;
// two AVX2 instructions
imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = _mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
} // namespace detail
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <type_traits>
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef struct {
void* gates; // gates: W_ch, W_ih, W_fh, W_oh
const void* ct_1;
void* ct;
void* ht;
/* weight_peephole and checked data are only used in peephole*/
const void* wp{nullptr};
void* checked{nullptr};
} lstm_t;
typedef struct {
void* gates; // gates: {W_update, W_reset; W_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d;
std::string act_gate, act_cand;
rnn_attr_s() = default;
rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
std::string act_cell;
lstm_attr_s() = default;
lstm_attr_s(int _d, const std::string& _act_gate,
const std::string& _act_cand, const std::string& _act_cell,
bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -82,10 +82,10 @@ namespace jitkernel { ...@@ -82,10 +82,10 @@ namespace jitkernel {
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \ #define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \ marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \ marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE, \ REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL); \ macro_find_key, macro_impl); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE, \ REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL) macro_find_key, macro_impl)
#define REGISTER_JITKERNEL(ker_key, ker_class) \ #define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \ REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace refer {
/* Refer code only focus on correctness */
template <typename T>
void VMul(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
template <typename T>
void VAddRelu(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
template <typename T>
void VScal(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] * x[i];
}
}
template <typename T>
void VAddBias(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] + x[i];
}
}
template <typename T>
void VRelu(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <typename T>
inline void VIdentity(const T* x, T* y, int n) {}
template <typename T>
void VExp(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
template <typename T>
void VSigmoid(const T* x, T* y, int n) {
// y = 1 / (1 + e^-x)
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
}
template <typename T>
void VTanh(const T* x, T* y, int n) {
// y = 2 * sigmoid(2x) - 1
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoid(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
template <typename T>
void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
if (type == "sigmoid") {
return VSigmoid<T>;
} else if (type == "relu") {
return VRelu<T>;
} else if (type == "tanh") {
return VTanh<T>;
} else if (type == "identity" || type == "") {
return VIdentity<T>;
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
// compute ct and ht
template <typename T>
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
const T* wp = reinterpret_cast<const T*>(step->wp);
T* checked = reinterpret_cast<T*>(step->checked);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
// gates: W_ch, W_ih, W_fh, W_oh
if (attr->use_peephole) {
VMul(wp, ct_1, checked, d);
VMul(wp + d, ct_1, checked + d, d);
VAdd(checked, gates + d, gates + d, d2);
act_gate(gates + d, gates + d, d2);
} else {
act_gate(gates + d, gates + d, d3);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand(gates, gates, d);
VMul(gates, gates + d, gates + d, d);
VMul(ct_1, gates + d2, gates + d2, d);
VAdd(gates + d, gates + d2, ct, d);
if (attr->use_peephole) {
// get ogated
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
act_gate(gates + d3, gates + d3, d);
}
// H_t = act_cell(C_t) * ogated
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
// compute c1 and h1 without c0 or h0
template <typename T>
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
/* C_t = igated * cgated*/
act_gate(gates + d, gates + d, d);
act_cand(gates, gates, d);
VMul(gates, gates + d, ct, d);
if (attr->use_peephole) {
// get outgated, put W_oc * C_t on igated
const T* wp = reinterpret_cast<const T*>(step->wp);
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
}
/* H_t = act_cell(C_t) * ogated */
act_gate(gates + d3, gates + d3, d);
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
// compute h1 without h0
template <typename T>
void GRUH1(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
int d = attr->d;
int d2 = d * 2;
act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d);
VMul(gates, gates + d2, ht, d);
}
// compute the first part of GRU: ht = act_gate(r) * ht_1
template <typename T>
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
// W: {W_update, W_reset; W_state}
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc<T>(attr->act_gate);
act_gate(gates + attr->d, gates + attr->d, attr->d);
VMul(ht_1, gates + attr->d, ht, attr->d);
}
// compute the second part of GRU:
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
template <typename T>
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
int d = attr->d;
T* y = gates + d * 2;
act_gate(gates, gates, d);
act_cand(y, y, d);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
} // namespace refer
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
...@@ -53,12 +54,6 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f), ...@@ -53,12 +54,6 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
} }
} }
void vrelu_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0.f ? x[i] : 0.f;
}
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vrelu_intri8(const int n, const float* x, float* y) { void vrelu_intri8(const int n, const float* x, float* y) {
__m256 tmp = _mm256_loadu_ps(x); __m256 tmp = _mm256_loadu_ps(x);
...@@ -69,6 +64,7 @@ void vrelu_intri8(const int n, const float* x, float* y) { ...@@ -69,6 +64,7 @@ void vrelu_intri8(const int n, const float* x, float* y) {
TEST(JitKernel, vrelu) { TEST(JitKernel, vrelu) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) { for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -80,7 +76,7 @@ TEST(JitKernel, vrelu) { ...@@ -80,7 +76,7 @@ TEST(JitKernel, vrelu) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vrelu_ref(d, x_data, zref_data); refer::VRelu<float>(x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
...@@ -90,7 +86,7 @@ TEST(JitKernel, vrelu) { ...@@ -90,7 +86,7 @@ TEST(JitKernel, vrelu) {
vrelu_intri8(d, x_data, zref_data); vrelu_intri8(d, x_data, zref_data);
} }
auto si1 = GetCurrentUS(); auto si1 = GetCurrentUS();
VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat << " us";
} }
#endif #endif
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -100,21 +96,16 @@ TEST(JitKernel, vrelu) { ...@@ -100,21 +96,16 @@ TEST(JitKernel, vrelu) {
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vaddbias_ref(const int n, const float a, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] + a;
}
}
TEST(JitKernel, vaddbias) { TEST(JitKernel, vaddbias) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) { for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -127,7 +118,7 @@ TEST(JitKernel, vaddbias) { ...@@ -127,7 +118,7 @@ TEST(JitKernel, vaddbias) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vaddbias_ref(d, a, x_data, zref_data); refer::VAddBias<float>(&a, x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -138,19 +129,13 @@ TEST(JitKernel, vaddbias) { ...@@ -138,19 +129,13 @@ TEST(JitKernel, vaddbias) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vexp_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
void vexp_mkl(const int n, const float* x, float* y) { void vexp_mkl(const int n, const float* x, float* y) {
paddle::platform::dynload::vsExp(n, x, y); paddle::platform::dynload::vsExp(n, x, y);
...@@ -159,6 +144,7 @@ void vexp_mkl(const int n, const float* x, float* y) { ...@@ -159,6 +144,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
TEST(JitKernel, vexp) { TEST(JitKernel, vexp) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) { for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -170,7 +156,7 @@ TEST(JitKernel, vexp) { ...@@ -170,7 +156,7 @@ TEST(JitKernel, vexp) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vexp_ref(d, x_data, zref_data); refer::VExp<float>(x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
...@@ -196,26 +182,13 @@ TEST(JitKernel, vexp) { ...@@ -196,26 +182,13 @@ TEST(JitKernel, vexp) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
inline float _sigmoid(float x) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (x < min) ? min : ((x > max) ? max : x);
return 1.f / (1.f + std::exp(-tmp));
}
void vsigmoid_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = _sigmoid(x[i]);
}
}
void vsigmoid_better( void vsigmoid_better(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp, const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp,
...@@ -234,6 +207,7 @@ void vsigmoid_better( ...@@ -234,6 +207,7 @@ void vsigmoid_better(
TEST(JitKernel, vsigmoid) { TEST(JitKernel, vsigmoid) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -252,7 +226,7 @@ TEST(JitKernel, vsigmoid) { ...@@ -252,7 +226,7 @@ TEST(JitKernel, vsigmoid) {
auto tmkle = GetCurrentUS(); auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vsigmoid_ref(d, x_data, zref_data); refer::VSigmoid<float>(x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -264,21 +238,13 @@ TEST(JitKernel, vsigmoid) { ...@@ -264,21 +238,13 @@ TEST(JitKernel, vsigmoid) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
inline float _tanh(float x) { return 2.f * _sigmoid(2.f * x) - 1.f; }
void vtanh_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = _tanh(x[i]);
}
}
void vtanh_better( void vtanh_better(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VScalKernel<float>>& vscal, const paddle::operators::math::jitkernel::VScalKernel<float>>& vscal,
...@@ -298,6 +264,7 @@ void vtanh_better( ...@@ -298,6 +264,7 @@ void vtanh_better(
TEST(JitKernel, vtanh) { TEST(JitKernel, vtanh) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -320,7 +287,7 @@ TEST(JitKernel, vtanh) { ...@@ -320,7 +287,7 @@ TEST(JitKernel, vtanh) {
auto tmkle = GetCurrentUS(); auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vtanh_ref(d, x_data, zref_data); refer::VTanh<float>(x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -332,39 +299,13 @@ TEST(JitKernel, vtanh) { ...@@ -332,39 +299,13 @@ TEST(JitKernel, vtanh) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void lstm_ctht_ref(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
vsigmoid_3d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VTanhKernel<float>>& vtanh_d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->Compute(gates, gates, d);
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
for (int k = 0; k < d; ++k) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[k] = ct_1[k] * f[k] + gates[k] * i[k];
// H_t = act_cell(C_t) * ogated
float tmp = ct[k] * 2;
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vexp_1->Compute(&tmp, &tmp, 1);
tmp = 2.f / (1.f + tmp) - 1.f;
ht[k] = tmp * o[k];
}
}
void lstm_ctht_better( void lstm_ctht_better(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>& const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
...@@ -389,6 +330,7 @@ void lstm_ctht_better( ...@@ -389,6 +330,7 @@ void lstm_ctht_better(
TEST(JitKernel, lstm) { TEST(JitKernel, lstm) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) { for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) {
int d4 = d * 4; int d4 = d * 4;
int d3 = d * 3; int d3 = d * 3;
...@@ -399,19 +341,17 @@ TEST(JitKernel, lstm) { ...@@ -399,19 +341,17 @@ TEST(JitKernel, lstm) {
RandomVec<float>(d, ct_1.data(), -2.f, 2.f); RandomVec<float>(d, ct_1.data(), -2.f, 2.f);
memcpy(xref.data(), x.data(), sizeof(float) * d4); memcpy(xref.data(), x.data(), sizeof(float) * d4);
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
const jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell, false);
const auto& ker = const auto& ker =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
const std::string&, const std::string&>( attr);
act_gate, act_cand, act_cell, d, false);
// below kernels are used to compute refer // below kernels are used to compute refer
const auto& vsigmoid_3d = const auto& vsigmoid_3d =
jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>( jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(
d3); d3);
const auto& vtanh_d = const auto& vtanh_d =
jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d); jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
const auto& vexp_1 =
jit::KernelPool::Instance().template Get<jit::VExpKernel<float>>(1);
const auto& vmul_d = const auto& vmul_d =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d); jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d);
const auto& vadd_d = const auto& vadd_d =
...@@ -425,9 +365,17 @@ TEST(JitKernel, lstm) { ...@@ -425,9 +365,17 @@ TEST(JitKernel, lstm) {
float* ct_ref_data = ct_ref.data(); float* ct_ref_data = ct_ref.data();
float* ht_ref_data = ht_ref.data(); float* ht_ref_data = ht_ref.data();
// compute once to check correctness // compute once to check correctness
lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data, jit::lstm_t step;
ct_ref_data, ht_ref_data); step.gates = xref_data;
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); step.ct_1 = ct_1_data;
step.ct = ct_ref_data;
step.ht = ht_ref_data;
refer::LSTMCtHt<float>(&step, &attr);
step.gates = x_data;
step.ct = ct_tgt_data;
step.ht = ht_tgt_data;
ker->ComputeCtHt(&step, &attr);
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3); EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3);
EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3); EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3);
...@@ -441,32 +389,21 @@ TEST(JitKernel, lstm) { ...@@ -441,32 +389,21 @@ TEST(JitKernel, lstm) {
auto tmkle = GetCurrentUS(); auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data, refer::LSTMCtHt<float>(&step, &attr);
ct_ref_data, ht_ref_data);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); ker->ComputeCtHt(&step, &attr);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit) takes: " << (tmkle - tmkls) / repeat << " us, better(jit) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
} }
} }
void vscal_ref(const int n, const float a, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = a * x[i];
}
}
void vscal_inp_ref(const int n, const float a, float* x) {
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vscal_intri8(const int n, const float a, const float* x, float* y) { void vscal_intri8(const int n, const float a, const float* x, float* y) {
__m256 tmp; __m256 tmp;
...@@ -492,6 +429,7 @@ void vscal_inp_mkl(const int n, const float a, float* x) { ...@@ -492,6 +429,7 @@ void vscal_inp_mkl(const int n, const float a, float* x) {
TEST(JitKernel, vscal) { TEST(JitKernel, vscal) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) { for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -506,12 +444,12 @@ TEST(JitKernel, vscal) { ...@@ -506,12 +444,12 @@ TEST(JitKernel, vscal) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vscal_ref(d, a, x_data, zref_data); refer::VScal<float>(&a, x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto trefs1 = GetCurrentUS(); auto trefs1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vscal_inp_ref(d, a, y_data); refer::VScal<float>(&a, y_data, y_data, d);
} }
auto trefe1 = GetCurrentUS(); auto trefe1 = GetCurrentUS();
...@@ -536,7 +474,7 @@ TEST(JitKernel, vscal) { ...@@ -536,7 +474,7 @@ TEST(JitKernel, vscal) {
} }
auto si3 = GetCurrentUS(); auto si3 = GetCurrentUS();
VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat
<< " us, inplace: " << (si3 - si2) / repeat; << " us, inplace: " << (si3 - si2) / repeat << " us";
} }
#endif #endif
...@@ -560,19 +498,14 @@ TEST(JitKernel, vscal) { ...@@ -560,19 +498,14 @@ TEST(JitKernel, vscal) {
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat << "tgt takes: " << (ttgte - ttgts) / repeat
<< "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat; << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat
<< " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vmul_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vmul_intri8(const int n, const float* x, const float* y, float* z) { void vmul_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy; __m256 tmpx, tmpy;
...@@ -591,6 +524,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) { ...@@ -591,6 +524,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) {
TEST(JitKernel, vmul) { TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) { for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -604,7 +538,7 @@ TEST(JitKernel, vmul) { ...@@ -604,7 +538,7 @@ TEST(JitKernel, vmul) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vmul_ref(d, x_data, y_data, zref_data); refer::VMul<float>(x_data, y_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
...@@ -640,19 +574,13 @@ TEST(JitKernel, vmul) { ...@@ -640,19 +574,13 @@ TEST(JitKernel, vmul) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vadd_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vadd_intri8(const int n, const float* x, const float* y, float* z) { void vadd_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy; __m256 tmpx, tmpy;
...@@ -671,6 +599,7 @@ void vadd_mkl(const int n, const float* x, const float* y, float* z) { ...@@ -671,6 +599,7 @@ void vadd_mkl(const int n, const float* x, const float* y, float* z) {
TEST(JitKernel, vadd) { TEST(JitKernel, vadd) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) { for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -684,7 +613,7 @@ TEST(JitKernel, vadd) { ...@@ -684,7 +613,7 @@ TEST(JitKernel, vadd) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vadd_ref(d, x_data, y_data, zref_data); refer::VAdd<float>(x_data, y_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
...@@ -720,19 +649,13 @@ TEST(JitKernel, vadd) { ...@@ -720,19 +649,13 @@ TEST(JitKernel, vadd) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vaddrelu_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
void vaddrelu_better( void vaddrelu_better(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd, const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
...@@ -745,6 +668,7 @@ void vaddrelu_better( ...@@ -745,6 +668,7 @@ void vaddrelu_better(
TEST(JitKernel, vaddrelu) { TEST(JitKernel, vaddrelu) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) { for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -762,7 +686,7 @@ TEST(JitKernel, vaddrelu) { ...@@ -762,7 +686,7 @@ TEST(JitKernel, vaddrelu) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vaddrelu_ref(d, x_data, y_data, zref_data); refer::VAddRelu<float>(x_data, y_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS(); auto tmkls = GetCurrentUS();
...@@ -778,7 +702,7 @@ TEST(JitKernel, vaddrelu) { ...@@ -778,7 +702,7 @@ TEST(JitKernel, vaddrelu) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better takes: " << (tmkle - tmkls) / repeat << " us, " << " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -789,21 +713,23 @@ TEST(JitKernel, pool) { ...@@ -789,21 +713,23 @@ TEST(JitKernel, pool) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
const int frame_size = 4; const int frame_size = 4;
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
// empty call it to avoid unknown flag 'use_pinned_memory' on Mac
paddle::platform::jit::MayIUse(paddle::platform::jit::avx);
const auto& plstm1 = const auto& plstm1 =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
const std::string&, const std::string&>(
act_gate, act_cand, act_cell, frame_size, false);
const auto& plstm2 = const auto& plstm2 =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
const std::string&, const std::string&>( EXPECT_EQ(plstm1, plstm2);
act_gate, act_cand, act_cell, frame_size, false);
const auto& peephole = const auto& peephole =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
const std::string&, const std::string&>( jit::lstm_attr_t(frame_size, act_gate, act_cand, act_cell, true));
act_gate, act_cand, act_cell, frame_size, true);
EXPECT_TRUE(plstm1 != peephole); EXPECT_TRUE(plstm1 != peephole);
const auto& pvmul_f = const auto& pvmul_f =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册