未验证 提交 30e47bce 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge branch 'develop' into revert_vlog

...@@ -43,6 +43,8 @@ RUN wget -q https://www.python.org/ftp/python/3.7.0/Python-3.7.0.tgz && \ ...@@ -43,6 +43,8 @@ RUN wget -q https://www.python.org/ftp/python/3.7.0/Python-3.7.0.tgz && \
CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null && \ CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null && \
make -j8 > /dev/null && make altinstall > /dev/null make -j8 > /dev/null && make altinstall > /dev/null
RUN rm -r /root/python_build
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --allow-downgrades patchelf \ apt-get install -y --allow-downgrades patchelf \
python3 python3-dev python3-pip \ python3 python3-dev python3-pip \
......
...@@ -192,11 +192,14 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -192,11 +192,14 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int M = x_dims[1]; \ const int M = x_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D2 = D * 2; \ const int D2 = D * 2; \
const auto& ker = math::jitkernel::KernelPool::Instance() \ const math::jitkernel::gru_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation")); \
math::jitkernel::gru_t one_step; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::GRUKernel<T>, \ .template Get<math::jitkernel::GRUKernel<T>, \
const std::string&, const std::string&>( \ const math::jitkernel::gru_attr_t&>(attr); \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation"), D); \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
...@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) { if (h0_data) {
prev_hidden_data = h0_data + bid * D; prev_hidden_data = h0_data + bid * D;
} else { } else {
ker->ComputeH1(xx_data, hidden_out_data); one_step.gates = xx_data;
one_step.ht = hidden_out_data;
ker->ComputeH1(&one_step, &attr);
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data, prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3); D3);
ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data); one_step.gates = xx_data;
one_step.ht_1 = prev_hidden_data;
one_step.ht = hidden_out_data;
ker->ComputeHtPart1(&one_step, &attr);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data); ker->ComputeHtPart2(&one_step, &attr);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
ker->ComputeH1(cur_in_data, cur_out_data); one_step.gates = cur_in_data;
one_step.ht = cur_out_data;
ker->ComputeH1(&one_step, &attr);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data, one_step.gates = cur_batched_data;
cur_out_data); one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data;
ker->ComputeHtPart1(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data, one_step.gates = cur_batched_data;
cur_out_data); one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data;
ker->ComputeHtPart2(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
......
...@@ -250,13 +250,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -250,13 +250,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} \ } \
const math::jitkernel::lstm_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), use_peepholes); \
math::jitkernel::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
const auto& ker = \ const auto& ker = \
math::jitkernel::KernelPool::Instance() \ math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \ .template Get<math::jitkernel::LSTMKernel<T>, \
const std::string&, const std::string&>( \ const math::jitkernel::lstm_attr_t&>(attr)
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
...@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_h_data = h0_data + bid * D; prev_h_data = h0_data + bid * D;
prev_c_data = c0_data + bid * D; prev_c_data = c0_data + bid * D;
} else { } else {
ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data); one_step.gates = xx_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeC1H1(&one_step, &attr);
tstart = 1; tstart = 1;
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
...@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
} }
for (int step = tstart; step < seq_len; ++step) { for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data); GEMM_WH_ADDON(1, prev_h_data, xx_data);
ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data,
checked_cell_data); one_step.gates = xx_data;
one_step.ct_1 = prev_c_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeCtHt(&one_step, &attr);
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
prev_c_data = c_out_data; prev_c_data = c_out_data;
...@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_h_out_data = batched_h_out_data; T* cur_h_out_data = batched_h_out_data;
T* cur_c_out_data = batched_c_out_data; T* cur_c_out_data = batched_c_out_data;
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data); one_step.gates = cur_in_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeC1H1(&one_step, &attr);
cur_in_data += D4; cur_in_data += D4;
cur_c_out_data += D; cur_c_out_data += D;
cur_h_out_data += D; cur_h_out_data += D;
...@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_c_out_data = batched_c_out_data; T* cur_c_out_data = batched_c_out_data;
T* cur_h_out_data = batched_h_out_data; T* cur_h_out_data = batched_h_out_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, one_step.gates = cur_in_data;
cur_h_out_data, wp_data, checked_cell_data); one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeCtHt(&one_step, &attr);
// move one batch // move one batch
cur_in_data += D4; cur_in_data += D4;
cur_prev_c_data += D; cur_prev_c_data += D;
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h" #include "paddle/fluid/operators/math/jit_code.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me #include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
namespace paddle { namespace paddle {
...@@ -139,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) { ...@@ -139,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
} }
void VActJitCode::generate() { void VActJitCode::generate() {
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_zero = ymm_t(2);
if (type_ == operand_type::relu) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
int offset = 0; int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]); vmovups(ymm_src, ptr[param1 + offset]);
switch (type_) { act<ymm_t>(ymm_dst, ymm_src, type_);
case operand_type::relu:
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
break;
case operand_type::exp:
exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::identity:
break;
default:
break;
}
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
...@@ -181,22 +160,7 @@ void VActJitCode::generate() { ...@@ -181,22 +160,7 @@ void VActJitCode::generate() {
block = 1; block = 1;
vmovss(xmm_src, ptr[param1 + offset]); vmovss(xmm_src, ptr[param1 + offset]);
} }
switch (type_) { act<xmm_t>(xmm_dst, xmm_src, type_);
case operand_type::relu:
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
}
if (rest >= 4) { if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst); vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) { } else if (rest >= 2) {
...@@ -210,6 +174,158 @@ void VActJitCode::generate() { ...@@ -210,6 +174,158 @@ void VActJitCode::generate() {
ret(); ret();
} }
bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void LSTMJitCode::generate() {
if (use_peephole_) {
preCode();
}
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t ymm_c = ymm_t(0);
ymm_t ymm_i = ymm_t(1);
ymm_t ymm_f = ymm_t(2);
ymm_t ymm_o = ymm_t(3);
ymm_t ymm_ct_1 = ymm_t(4);
ymm_t ymm_wp0 = ymm_t(5);
ymm_t ymm_wp1 = ymm_t(6);
ymm_t ymm_wp2 = ymm_t(7);
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
if (!compute_c1h1_) {
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
}
if (use_peephole_) {
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act<ymm_t>(ymm_c, ymm_c, act_cand_);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if (!compute_c1h1_ && use_peephole_) {
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
vaddps(ymm_i, ymm_i, ymm_wp0);
}
act<ymm_t>(ymm_i, ymm_i, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) {
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if (use_peephole_) {
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
vaddps(ymm_f, ymm_f, ymm_wp1);
}
act<ymm_t>(ymm_f, ymm_f, act_gate_);
// ct
vmulps(ymm_f, ymm_f, ymm_ct_1);
vaddps(ymm_f, ymm_f, ymm_c);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
// act_gate(o) or act_gate(ct * wp2 + o)
if (use_peephole_) {
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
vaddps(ymm_o, ymm_o, ymm_wp2);
}
act<ymm_t>(ymm_o, ymm_o, act_gate_);
// ht
vmulps(ymm_o, ymm_o, ymm_tmp);
// save ct and ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
if (use_peephole_) {
postCode();
} else {
ret();
}
}
bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void GRUJitCode::generate() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ht = r10;
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
ymm_t ymm_one = ymm_t(0);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret();
}
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_gen.h" #include "paddle/fluid/operators/math/jit_gen.h"
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
...@@ -46,14 +47,6 @@ extern const float exp_float_consts[]; ...@@ -46,14 +47,6 @@ extern const float exp_float_consts[];
extern const int exp_int_0x7f[]; extern const int exp_int_0x7f[];
extern int g_tmp_mem[]; extern int g_tmp_mem[];
// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
#define ALIGN32 __attribute__((aligned(32))) #define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f #define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f #define EXP_LOW -88.3762626647949f
...@@ -176,31 +169,34 @@ class VActJitCode : public JitCode { ...@@ -176,31 +169,34 @@ class VActJitCode : public JitCode {
protected: protected:
// compute relu with ymm, xmm // compute relu with ymm, xmm
template <typename JMM> template <typename JMM>
void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT
JMM zero = JMM(zero_idx);
vxorps(zero, zero, zero);
vmaxps(dst, src, zero); vmaxps(dst, src, zero);
} }
// compute exp with ymm, xmm // compute exp with ymm, xmm
template <typename JMM> template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int mask_idx = 4, int tmp_idx = 5) { int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
using namespace platform::jit; // NOLINT using namespace platform::jit; // NOLINT
assert(src.getIdx() != dst.getIdx()); // TODO(TJ): use enfore
// check all idx can not equal // check all idx can not equal
JMM jmm_src = JMM(src_idx);
JMM jmm_fx = JMM(fx_idx); JMM jmm_fx = JMM(fx_idx);
JMM jmm_fy = JMM(fy_idx); JMM jmm_fy = JMM(fy_idx);
JMM jmm_mask = JMM(mask_idx); JMM jmm_mask = JMM(mask_idx);
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(src, src, jmm_tmp); vminps(jmm_src, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
vmaxps(src, src, jmm_tmp); vmaxps(jmm_src, jmm_src, jmm_tmp);
// express exp(x) as exp(g + n*log(2)) // express exp(x) as exp(g + n*log(2))
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
vmulps(jmm_fx, src, jmm_tmp); vmulps(jmm_fx, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
vaddps(jmm_fx, jmm_fx, jmm_tmp); vaddps(jmm_fx, jmm_fx, jmm_tmp);
vroundps(jmm_fy, jmm_fx, 0x01); vroundps(jmm_fy, jmm_fx, 0x01);
...@@ -214,21 +210,21 @@ class VActJitCode : public JitCode { ...@@ -214,21 +210,21 @@ class VActJitCode : public JitCode {
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
JMM ymm_z = JMM(jmm_mask.getIdx()); JMM ymm_z = JMM(jmm_mask.getIdx());
vmulps(ymm_z, jmm_fx, jmm_tmp); vmulps(ymm_z, jmm_fx, jmm_tmp);
vsubps(src, src, jmm_fy); vsubps(jmm_src, jmm_src, jmm_fy);
vsubps(src, src, ymm_z); vsubps(jmm_src, jmm_src, ymm_z);
vmulps(ymm_z, src, src); vmulps(ymm_z, jmm_src, jmm_src);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(dst, src, jmm_tmp); vmulps(dst, jmm_src, jmm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (YMM_FLOAT_BLOCK * sizeof(float))) { i += (YMM_FLOAT_BLOCK * sizeof(float))) {
vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4 vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmulps(dst, dst, src); vmulps(dst, dst, jmm_src);
} }
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmulps(dst, dst, ymm_z); vmulps(dst, dst, ymm_z);
vaddps(dst, dst, src); vaddps(dst, dst, jmm_src);
vmovaps(jmm_tmp, ptr[reg_ptr_global]); vmovaps(jmm_tmp, ptr[reg_ptr_global]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
// build 2^n // build 2^n
...@@ -265,20 +261,23 @@ class VActJitCode : public JitCode { ...@@ -265,20 +261,23 @@ class VActJitCode : public JitCode {
// compute sigmoid with ymm, xmm // compute sigmoid with ymm, xmm
template <typename JMM> template <typename JMM>
void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2, // NOLINT void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) { int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
int tmp_idx = 15) {
// y = 1 / (1 + e^-x) // y = 1 / (1 + e^-x)
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
JMM jmm_src = JMM(src_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
vminps(src, src, jmm_tmp); vminps(jmm_src, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
vmaxps(src, src, jmm_tmp); vmaxps(jmm_src, jmm_src, jmm_tmp);
vxorps(jmm_tmp, jmm_tmp, jmm_tmp); vxorps(jmm_tmp, jmm_tmp, jmm_tmp);
vsubps(src, jmm_tmp, src); vsubps(jmm_src, jmm_tmp, jmm_src);
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); exp_jmm<JMM>(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vdivps(dst, jmm_tmp, dst); vdivps(dst, jmm_tmp, dst);
...@@ -287,19 +286,22 @@ class VActJitCode : public JitCode { ...@@ -287,19 +286,22 @@ class VActJitCode : public JitCode {
// compute tanh with ymm, xmm // compute tanh with ymm, xmm
template <typename JMM> template <typename JMM>
void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int mask_idx = 4, int tmp_idx = 5) { int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
int tmp_idx = 15) {
// y = 2 / (1 + e^(-2x)) - 1 // y = 2 / (1 + e^(-2x)) - 1
JMM jmm_src = JMM(src_idx);
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
JMM jmm_zero = JMM(mask_idx); JMM jmm_zero = JMM(mask_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(jmm_zero, jmm_zero, jmm_zero); vxorps(jmm_zero, jmm_zero, jmm_zero);
vsubps(jmm_tmp, jmm_zero, jmm_tmp); vsubps(jmm_tmp, jmm_zero, jmm_tmp);
vmulps(src, src, jmm_tmp); vmulps(jmm_src, jmm_src, jmm_tmp);
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); exp_jmm<JMM>(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
...@@ -309,6 +311,30 @@ class VActJitCode : public JitCode { ...@@ -309,6 +311,30 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
template <typename JMM>
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
// use 11~15
switch (type) {
case operand_type::relu:
relu_jmm<JMM>(dst, src, 15);
break;
case operand_type::exp:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::sigmoid:
sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::tanh:
tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::identity:
break;
default:
// throw error
break;
}
}
protected: protected:
int num_; int num_;
operand_type type_; operand_type type_;
...@@ -322,6 +348,148 @@ class VActJitCode : public JitCode { ...@@ -322,6 +348,148 @@ class VActJitCode : public JitCode {
ymm_t ymm_dst = ymm_t(1); ymm_t ymm_dst = ymm_t(1);
}; };
class LSTMJitCode : public VActJitCode {
public:
const char* name() const override {
std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
}
if (compute_c1h1_) {
base += "_C1H1";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
compute_c1h1_(compute_c1h1) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
use_peephole_ = attr.use_peephole;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
}
static bool init(int d);
void generate() override;
protected:
int num_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
reg64_t param1{abi_param1};
};
class GRUJitCode : public VActJitCode {
public:
const char* name() const override {
std::string base = "GRUJitCode";
if (id_ == 0) {
base += "_H1";
} else if (id_ == 1) {
base += "_HtPart1";
} else if (id_ == 2) {
base += "_HtPart2";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
}
explicit GRUJitCode(int id, const gru_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
id_(id) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
}
static bool init(int d);
void generate() override;
protected:
int id_;
int num_;
operand_type act_gate_;
operand_type act_cand_;
reg64_t param1{abi_param1};
};
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator {
explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024) explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory> // for shared_ptr #include <memory> // for shared_ptr
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -26,14 +27,7 @@ namespace operators { ...@@ -26,14 +27,7 @@ namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
// TODO(TJ): move these to some proper place // TODO(TJ): remove me
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
class Kernel { class Kernel {
...@@ -128,24 +122,18 @@ class VTanhKernel : public VActKernel<T> {}; ...@@ -128,24 +122,18 @@ class VTanhKernel : public VActKernel<T> {};
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class LSTMKernel : public Kernel {
public: public:
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht,
/* below only used in peephole*/
const T *wp_data = nullptr,
T *checked = nullptr) const = 0;
// compute c1 and h1 without c0 or h0 // compute c1 and h1 without c0 or h0
virtual void ComputeC1H1(T *gates, T *ct, T *ht, void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
/* below only used in peephole*/ void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
const T *wp_data = nullptr) const = 0;
}; };
template <typename T> template <typename T>
class GRUKernel : public Kernel { class GRUKernel : public Kernel {
public: public:
// compute h1 without h0 // compute h1 without h0
virtual void ComputeH1(T *gates, T *ht) const = 0; void (*ComputeH1)(gru_t *, const gru_attr_t *);
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0; void (*ComputeHtPart1)(gru_t *, const gru_attr_t *);
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0; void (*ComputeHtPart2)(gru_t *, const gru_attr_t *);
}; };
template <typename T> template <typename T>
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -31,49 +32,6 @@ namespace math { ...@@ -31,49 +32,6 @@ namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
template <typename T>
void VMulRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
template <typename T>
void VAddRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
template <typename T>
void VAddReluRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
template <typename T>
void VScalRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] * x[i];
}
}
template <typename T>
void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] + x[i];
}
}
template <typename T>
void VReluRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
template <typename T> template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n); void VMulMKL(const T* x, const T* y, T* z, int n);
...@@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) { ...@@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if (x == y) { if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1); platform::dynload::cblas_sscal(n, *a, y, 1);
} else { } else {
VScalRefer<float>(a, x, y, n); refer::VScal<float>(a, x, y, n);
} }
} }
...@@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) { ...@@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if (x == y) { if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1); platform::dynload::cblas_dscal(n, *a, y, 1);
} else { } else {
VScalRefer<double>(a, x, y, n); refer::VScal<double>(a, x, y, n);
} }
} }
...@@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VMulRefer<T>; this->Compute = refer::VMul<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> { ...@@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VAddRefer<T>; this->Compute = refer::VAdd<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -280,7 +238,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> { ...@@ -280,7 +238,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VAddReluRefer<T>; this->Compute = refer::VAddRelu<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -318,7 +276,7 @@ class VScalKernelImpl : public VScalKernel<T> { ...@@ -318,7 +276,7 @@ class VScalKernelImpl : public VScalKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VScalRefer<T>; this->Compute = refer::VScal<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -362,7 +320,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> { ...@@ -362,7 +320,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
} }
#endif #endif
this->Compute = VAddBiasRefer<T>; this->Compute = refer::VAddBias<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -396,7 +354,7 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -396,7 +354,7 @@ class VReluKernelImpl : public VReluKernel<T> {
} }
#endif #endif
this->Compute = VReluRefer<T>; this->Compute = refer::VRelu<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -412,16 +370,13 @@ bool VReluKernelImpl<float>::useJIT(int d) { ...@@ -412,16 +370,13 @@ bool VReluKernelImpl<float>::useJIT(int d) {
} }
#endif #endif
template <typename T>
inline void VIdentityRefer(const T* x, T* y, int n) {}
/* An empty JitKernel */ /* An empty JitKernel */
template <typename T> template <typename T>
class VIdentityKernelImpl : public VIdentityKernel<T> { class VIdentityKernelImpl : public VIdentityKernel<T> {
public: public:
JITKERNEL_DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
this->Compute = VIdentityRefer<T>; this->Compute = refer::VIdentity<T>;
} }
}; };
......
...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath> // for exp
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h" #include "paddle/fluid/operators/math/jit_code.h"
...@@ -25,48 +25,12 @@ limitations under the License. */ ...@@ -25,48 +25,12 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
// TODO(TJ): move refer codes to one file
// Refer code only focus on correctness
template <typename T>
void VExpRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
template <typename T>
void VSigmoidRefer(const T* x, T* y, int n) {
// y = 1 / (1 + e^-x)
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
}
template <typename T>
void VTanhRefer(const T* x, T* y, int n) {
// y = 2 * sigmoid(2x) - 1
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoidRefer(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup // try to use MKL to speedup
template <typename T> template <typename T>
...@@ -129,7 +93,7 @@ class VExpKernelImpl : public VExpKernel<T> { ...@@ -129,7 +93,7 @@ class VExpKernelImpl : public VExpKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VExpRefer<T>; this->Compute = refer::VExp<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -182,7 +146,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -182,7 +146,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VSigmoidRefer<T>; this->Compute = refer::VSigmoid<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -234,7 +198,7 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -234,7 +198,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
return; return;
} }
#endif #endif
this->Compute = VTanhRefer<T>; this->Compute = refer::VTanh<T>;
} }
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
...@@ -267,154 +231,6 @@ REGISTER_JITKERNEL(vexp, VExpKernel); ...@@ -267,154 +231,6 @@ REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel); REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
REGISTER_JITKERNEL(vtanh, VTanhKernel); REGISTER_JITKERNEL(vtanh, VTanhKernel);
namespace detail {
#ifdef __AVX__
#define ALIGN32 __attribute__((aligned(32)))
#define _PS256_CONST(Name, Val) \
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
#define _PI256_CONST(Name, Val) \
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
_PI256_CONST(0x7f, 0x7f);
_PS256_CONST(one, 1.f);
_PS256_CONST(0p5, 0.5f);
_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
typedef union imm_xmm_union {
__m256i imm;
__m128i xmm[2];
} imm_xmm_union;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \
imm_xmm_union u ALIGN32; \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \
imm_xmm_union u ALIGN32; \
u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \
imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
/* use SSE2 to perform the bitop AVX2 */ \
__m128i x1, x2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1, y); \
x2 = _mm_##fn(x2, y); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
#define AVX2_INTOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
/* use SSE2 to perform the AVX2 integer operation */ \
__m128i x1, x2; \
__m128i y1, y2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1, y1); \
x2 = _mm_##fn(x2, y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
AVX2_BITOP_USING_SSE2(slli_epi32);
AVX2_INTOP_USING_SSE2(add_epi32);
#define AVXEXP_BASE \
__m256 tmp = _mm256_setzero_ps(), fx; \
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one); \
__m256i imm0; \
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); \
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo)); \
/* express exp(x) as exp(g + n*log(2)) */ \
fx = _mm256_mul_ps(x, \
*reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5)); \
tmp = _mm256_floor_ps(fx); \
/* if greater, substract 1 */ \
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
mask = _mm256_and_ps(mask, one); \
fx = _mm256_sub_ps(tmp, mask); \
tmp = _mm256_mul_ps(fx, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
__m256 z = _mm256_mul_ps( \
fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2)); \
x = _mm256_sub_ps(x, tmp); \
x = _mm256_sub_ps(x, z); \
z = _mm256_mul_ps(x, x); \
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5)); \
y = _mm256_mul_ps(y, z); \
y = _mm256_add_ps(y, x); \
y = _mm256_add_ps(y, one); \
/* build 2^n */ \
imm0 = _mm256_cvttps_epi32(fx)
__m256 ExpAVX(__m256 x) {
AVXEXP_BASE;
// two AVX2 instructions using SSE2
imm0 = avx2_mm256_add_epi32(imm0,
*reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = avx2_mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x) {
AVXEXP_BASE;
// two AVX2 instructions
imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = _mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
} // namespace detail
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <type_traits>
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef struct {
void* gates; // gates: W_ch, W_ih, W_fh, W_oh
const void* ct_1;
void* ct;
void* ht;
/* weight_peephole and checked data are only used in peephole*/
const void* wp{nullptr};
void* checked{nullptr};
} lstm_t;
typedef struct {
void* gates; // gates: {W_update, W_reset; W_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d;
std::string act_gate, act_cand;
rnn_attr_s() = default;
rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
std::string act_cell;
lstm_attr_s() = default;
lstm_attr_s(int _d, const std::string& _act_gate,
const std::string& _act_cand, const std::string& _act_cell,
bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -82,10 +82,10 @@ namespace jitkernel { ...@@ -82,10 +82,10 @@ namespace jitkernel {
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \ #define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \ marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \ marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE, \ REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL); \ macro_find_key, macro_impl); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE, \ REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL) macro_find_key, macro_impl)
#define REGISTER_JITKERNEL(ker_key, ker_class) \ #define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \ REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace refer {
/* Refer code only focus on correctness */
template <typename T>
void VMul(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
template <typename T>
void VAddRelu(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
template <typename T>
void VScal(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] * x[i];
}
}
template <typename T>
void VAddBias(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] + x[i];
}
}
template <typename T>
void VRelu(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <typename T>
inline void VIdentity(const T* x, T* y, int n) {}
template <typename T>
void VExp(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
template <typename T>
void VSigmoid(const T* x, T* y, int n) {
// y = 1 / (1 + e^-x)
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
}
template <typename T>
void VTanh(const T* x, T* y, int n) {
// y = 2 * sigmoid(2x) - 1
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoid(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
template <typename T>
void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
if (type == "sigmoid") {
return VSigmoid<T>;
} else if (type == "relu") {
return VRelu<T>;
} else if (type == "tanh") {
return VTanh<T>;
} else if (type == "identity" || type == "") {
return VIdentity<T>;
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
// compute ct and ht
template <typename T>
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
const T* wp = reinterpret_cast<const T*>(step->wp);
T* checked = reinterpret_cast<T*>(step->checked);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
// gates: W_ch, W_ih, W_fh, W_oh
if (attr->use_peephole) {
VMul(wp, ct_1, checked, d);
VMul(wp + d, ct_1, checked + d, d);
VAdd(checked, gates + d, gates + d, d2);
act_gate(gates + d, gates + d, d2);
} else {
act_gate(gates + d, gates + d, d3);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand(gates, gates, d);
VMul(gates, gates + d, gates + d, d);
VMul(ct_1, gates + d2, gates + d2, d);
VAdd(gates + d, gates + d2, ct, d);
if (attr->use_peephole) {
// get ogated
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
act_gate(gates + d3, gates + d3, d);
}
// H_t = act_cell(C_t) * ogated
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
// compute c1 and h1 without c0 or h0
template <typename T>
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
/* C_t = igated * cgated*/
act_gate(gates + d, gates + d, d);
act_cand(gates, gates, d);
VMul(gates, gates + d, ct, d);
if (attr->use_peephole) {
// get outgated, put W_oc * C_t on igated
const T* wp = reinterpret_cast<const T*>(step->wp);
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
}
/* H_t = act_cell(C_t) * ogated */
act_gate(gates + d3, gates + d3, d);
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
// compute h1 without h0
template <typename T>
void GRUH1(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
int d = attr->d;
int d2 = d * 2;
act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d);
VMul(gates, gates + d2, ht, d);
}
// compute the first part of GRU: ht = act_gate(r) * ht_1
template <typename T>
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
// W: {W_update, W_reset; W_state}
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc<T>(attr->act_gate);
act_gate(gates + attr->d, gates + attr->d, attr->d);
VMul(ht_1, gates + attr->d, ht, attr->d);
}
// compute the second part of GRU:
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
template <typename T>
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
int d = attr->d;
T* y = gates + d * 2;
act_gate(gates, gates, d);
act_cand(y, y, d);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
} // namespace refer
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -15,470 +15,248 @@ limitations under the License. */ ...@@ -15,470 +15,248 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#ifdef __AVX__ #ifdef PADDLE_WITH_XBYAK
#include <immintrin.h> #include "paddle/fluid/operators/math/jit_code.h"
#endif #endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace detail {
#ifdef __AVX__
__m256 ExpAVX(__m256 x);
#endif
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x);
#endif
} // namespace detail
namespace jit = platform::jit;
#ifdef __AVX__
typedef enum { kSigmoid, kRelu, kTanh, kIdentity } act_type;
class AVXAct { /* LSTM JitKernel */
public: template <typename T>
virtual ~AVXAct() = default; class LSTMKernelImpl : public LSTMKernel<T> {
virtual __m256 Compute(__m256 x) const = 0;
};
template <act_type type, jit::cpu_isa_t isa>
class AVXActImpl : public AVXAct {
public: public:
__m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); } static inline std::string name(const lstm_attr_t& attr) {
}; PADDLE_THROW("DType should be either float or double");
#define AVX_SIGMOID(isa, expisa) \
template <> \
__m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
__m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
x = expisa(x); \
x = _mm256_add_ps(ones, x); \
return _mm256_div_ps(ones, x); \
}
#define AVX_TANH(isa, expisa) \
template <> \
__m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
__m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
x = expisa(x); \
x = _mm256_add_ps(ones, x); \
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
return _mm256_sub_ps(x, ones); \
} }
static inline bool useJIT(int d) { return false; }
#define AVX_RELU(isa) \ static inline bool useMKL(int d) { return false; }
template <> \ explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
__m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \ #ifdef PADDLE_WITH_XBYAK
return _mm256_max_ps(x, _mm256_setzero_ps()); \ if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8;
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
this->ComputeC1H1 =
jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
return;
} }
#endif
#define AVX_IDENTITY(isa) \ this->ComputeCtHt = refer::LSTMCtHt<T>;
template <> \ this->ComputeC1H1 = refer::LSTMC1H1<T>;
__m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
return x; \
} }
#define FOR_EACH_AVX_ISA(macro_) \ #ifdef PADDLE_WITH_XBYAK
macro_(jit::avx); \
macro_(jit::avx2); \
macro_(jit::avx512f)
FOR_EACH_AVX_ISA(AVX_RELU);
FOR_EACH_AVX_ISA(AVX_IDENTITY);
AVX_SIGMOID(jit::avx, detail::ExpAVX);
AVX_TANH(jit::avx, detail::ExpAVX);
#ifdef __AVX2__ private:
AVX_SIGMOID(jit::avx2, detail::ExpAVX2); std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
AVX_TANH(jit::avx2, detail::ExpAVX2);
AVX_TANH(jit::avx512f, detail::ExpAVX2);
#endif #endif
};
#undef FOR_EACH_AVX_ISA #ifdef PADDLE_WITH_XBYAK
#undef AVX_IDENTITY template <>
#undef AVX_RELU bool LSTMKernelImpl<float>::useJIT(int d) {
#undef AVX_TANH return gen::LSTMJitCode::init(d);
#undef AVX_SIGMOID }
#endif #endif
/* Peephole JitKernel */
template <typename T> template <typename T>
static std::shared_ptr<const VActKernel<T>> GetActKernel( class PeepholeKernelImpl : public LSTMKernel<T> {
const std::string& type, int n) { public:
if (type == "sigmoid") { static inline std::string name(const lstm_attr_t& attr) {
return std::dynamic_pointer_cast<const VActKernel<T>>( PADDLE_THROW("DType should be either float or double");
KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
} else if (type == "relu") {
return std::dynamic_pointer_cast<const VActKernel<T>>(
KernelPool::Instance().template Get<VReluKernel<T>>(n));
} else if (type == "tanh") {
return std::dynamic_pointer_cast<const VActKernel<T>>(
KernelPool::Instance().template Get<VTanhKernel<T>>(n));
} else if (type == "identity" || type == "") {
return std::dynamic_pointer_cast<const VActKernel<T>>(
KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
} }
PADDLE_THROW("Not support type: %s", type); static inline bool useJIT(int d) { return false; }
return nullptr; static inline bool useMKL(int d) { return false; }
} explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__ if (useJIT(attr.d)) {
template <jit::cpu_isa_t isa> size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 4 * 8;
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) { jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
if (type == "sigmoid") { this->ComputeCtHt =
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
} else if (type == "tanh") { this->ComputeC1H1 =
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
} else if (type == "identity" || type == "") { return;
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
} }
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
#endif #endif
/* LSTM JitKernel */ this->ComputeCtHt = refer::LSTMCtHt<T>;
template <typename T, jit::cpu_isa_t isa, jit_block> this->ComputeC1H1 = refer::LSTMC1H1<T>;
class LSTMKernelImpl : public LSTMKernel<T> {
public:
explicit LSTMKernelImpl(const std::string& act_gate,
const std::string& act_cand,
const std::string& act_cell, int d)
: LSTMKernel<T>() {
d_ = d;
d2_ = d * 2;
d3_ = d * 3;
act_gate_d3_ = GetActKernel<T>(act_gate, d3_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_cand_d_ = GetActKernel<T>(act_cand, d);
act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
} }
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, #ifdef PADDLE_WITH_XBYAK
T* checked) const override {
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_->Compute(gates + d_, gates + d_, d3_);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
private: private:
int d_, d2_, d3_; std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
std::shared_ptr<const VActKernel<T>> act_gate_d3_, act_gate_d_, act_cand_d_,
act_cell_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
std::shared_ptr<const VAddKernel<T>> vadd_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_cand_, avx_act_cell_;
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #ifdef PADDLE_WITH_XBYAK
template <>
bool PeepholeKernelImpl<float>::useJIT(int d) {
return gen::LSTMJitCode::init(d);
}
#endif
#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \
template <> \ template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \ std::string ker_class##Impl<float>::name(const lstm_attr_t& attr) { \
const std::string& act_gate, const std::string& act_cand, \ std::string key(#ker_key "f"); \
const std::string& act_cell, int d) \ key += (attr.act_gate + attr.act_cand + attr.act_cell + \
: LSTMKernel<float>() { \ (attr.use_peephole ? "p" : "n")); \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \ if (useJIT(attr.d)) { \
avx_act_cand_ = GetAVXAct<isa>(act_cand); \ /* only jit code need record d*/ \
avx_act_cell_ = GetAVXAct<isa>(act_cell); \ return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \ } \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */ \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \ } \
template <> \ template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \ std::string ker_class##Impl<double>::name(const lstm_attr_t& attr) { \
float* gates, float* ct, float* ht, const float* wp_data) const { \ std::string key(#ker_key "d"); \
__m256 c, i, o; \ /* jit code do not support double yet*/ \
c = _mm256_loadu_ps(gates); \ if (useMKL(attr.d)) { \
i = _mm256_loadu_ps(gates + 8); \ return key + "mkl"; \
o = _mm256_loadu_ps(gates + 24); \ } else { \
/* C_t = igated * cgated*/ \ return key + "any"; \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \ } \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
// TODO(TJ): optimize keq16
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
/* Peephole JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class PeepholeKernelImpl : public LSTMKernel<T> {
public:
explicit PeepholeKernelImpl(const std::string& act_gate,
const std::string& act_cand,
const std::string& act_cell, int d)
: LSTMKernel<T>() {
d_ = d;
d2_ = d * 2;
d3_ = d * 3;
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_cand_d_ = GetActKernel<T>(act_cand, d);
act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
vadd_d2_ = KernelPool::Instance().template Get<VAddKernel<T>>(d2_);
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
}
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
T* checked) const override {
/* get fgated and igated*/
vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
private:
int d_, d2_, d3_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_cand_d_,
act_cell_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
std::shared_ptr<const VAddKernel<T>> vadd_d_, vadd_d2_;
};
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ #define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \ template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \ std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const std::string&, \ KernelPool::Get<LSTMKernel<ker_dtype>, const lstm_attr_t&>( \
const std::string&, const std::string&, int, bool>( \ const lstm_attr_t& attr)
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d, bool use_peephole)
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \ #define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \ std::string key = ker_class##Impl<ker_dtype>::name(attr)
(use_peephole ? "p" : "n")
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \ #define JITKERNEL_LSTM_IMPL(ker, dtype) \
if (use_peephole) { \ if (attr.use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \ p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype, isa, k>>( \ std::make_shared<PeepholeKernelImpl<dtype>>(attr)); \
act_gate, act_cand, act_cell, d)); \
} else { \ } else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \ p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_cand, \ std::make_shared<ker##Impl<dtype>>(attr)); \
act_cell, d)); \
} }
REGISTER_JITKERNEL_ARGS_DEPRECATED(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM,
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL); JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM,
JITKERNEL_LSTM_IMPL);
#undef INTRI8_FLOAT #undef JITKERNEL_LSTM_IMPL
#undef JITKERNEL_FIND_KEY_LSTM
#undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM #undef JITKERNEL_DEFINE_NAME_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */ /* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T>
class GRUKernelImpl : public GRUKernel<T> { class GRUKernelImpl : public GRUKernel<T> {
public: public:
explicit GRUKernelImpl(const std::string& act_gate, static inline std::string name(const gru_attr_t& attr) {
const std::string& act_state, int d) PADDLE_THROW("DType should be either float or double");
: GRUKernel<T>() {
d_ = d;
d2_ = d * 2;
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_state_d_ = GetActKernel<T>(act_state, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
}
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates, d_);
act_state_d_->Compute(gates + d2_, gates + d2_, d_);
vmul_d_->Compute(gates, gates + d2_, ht, d_);
} }
static inline bool useJIT(int d) { return false; }
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { static inline bool useMKL(int d) { return false; }
// W: {W_update, W_reset; W_state} explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel<T>() {
act_gate_d2_->Compute(gates, gates, d2_); #ifdef PADDLE_WITH_XBYAK
vmul_d_->Compute(ht_1, gates + d_, ht, d_); if (useJIT(attr.d)) {
} size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8;
jitcode0_.reset(new gen::GRUJitCode(0, attr, sz > 4096 ? sz : 4096));
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { this->ComputeH1 =
T* y = gates + d2_; jitcode0_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
act_state_d_->Compute(y, y, d_);
// out = zt*ht~ + (1-zt)*ht_1 jitcode1_.reset(new gen::GRUJitCode(1, attr, sz > 4096 ? sz : 4096));
for (int i = 0; i < d_; ++i) { this->ComputeHtPart1 =
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i]; jitcode1_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
jitcode2_.reset(new gen::GRUJitCode(2, attr, sz > 4096 ? sz : 4096));
this->ComputeHtPart2 =
jitcode2_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
return;
} }
#endif
this->ComputeH1 = refer::GRUH1<T>;
this->ComputeHtPart1 = refer::GRUHtPart1<T>;
this->ComputeHtPart2 = refer::GRUHtPart2<T>;
} }
#ifdef PADDLE_WITH_XBYAK
private: private:
int d_, d2_; std::unique_ptr<gen::GRUJitCode> jitcode0_{nullptr}, jitcode1_{nullptr},
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_; jitcode2_{nullptr};
std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #ifdef PADDLE_WITH_XBYAK
template <>
bool GRUKernelImpl<float>::useJIT(int d) {
return gen::GRUJitCode::init(d);
}
#endif
#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \
template <> \ template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \ std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \
const std::string& act_gate, const std::string& act_state, int d) \ std::string key(#ker_key "f"); \
: GRUKernel<float>() { \ key += (attr.act_gate + attr.act_cand); \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \ if (useJIT(attr.d)) { \
avx_act_state_ = GetAVXAct<isa>(act_state); \ /* only jit code need record d*/ \
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \ } \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */ \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \ } \
template <> \ template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \ std::string ker_class##Impl<double>::name(const gru_attr_t& attr) { \
float* gates, const float* ht_1, float* ht) const { \ std::string key(#ker_key "d"); \
/* not exactly equal the any implementation */ \ /* jit code do not support double yet*/ \
__m256 r, ht0; \ if (useMKL(attr.d)) { \
r = _mm256_loadu_ps(gates + 8); \ return key + "mkl"; \
ht0 = _mm256_loadu_ps(ht_1); \ } else { \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \ return key + "any"; \
_mm256_storeu_ps(ht, r); \
} \ } \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
} }
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \ #define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \ template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \ std::shared_ptr<const ker_class<ker_dtype>> \
GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \ KernelPool::Get<ker_class<ker_dtype>, const gru_attr_t&>( \
const std::string& act_gate, const std::string& act_state, int d) const gru_attr_t& attr)
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \ #define JITKERNEL_FIND_KEY_GRU(ker_class, ker_dtype) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_state std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \ #define JITKERNEL_GRU_IMPL(ker, dtype) \
p = std::dynamic_pointer_cast<ker<dtype>>( \ p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d)); std::make_shared<ker##Impl<dtype>>(attr));
REGISTER_JITKERNEL_ARGS_DEPRECATED(gru, GRUKernel, JITKERNEL_DECLARE_GRU, REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DEFINE_NAME_GRU,
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL); JITKERNEL_DECLARE_GRU, JITKERNEL_FIND_KEY_GRU,
JITKERNEL_GRU_IMPL);
#undef INTRI8_FLOAT #undef JITKERNEL_GRU_IMPL
#undef JITKERNEL_NEW_GRU_IMPL #undef JITKERNEL_FIND_KEY_GRU
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU #undef JITKERNEL_DECLARE_GRU
#undef JITKERNEL_DEFINE_NAME_GRU
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
...@@ -53,12 +54,6 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f), ...@@ -53,12 +54,6 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
} }
} }
void vrelu_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0.f ? x[i] : 0.f;
}
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vrelu_intri8(const int n, const float* x, float* y) { void vrelu_intri8(const int n, const float* x, float* y) {
__m256 tmp = _mm256_loadu_ps(x); __m256 tmp = _mm256_loadu_ps(x);
...@@ -69,6 +64,7 @@ void vrelu_intri8(const int n, const float* x, float* y) { ...@@ -69,6 +64,7 @@ void vrelu_intri8(const int n, const float* x, float* y) {
TEST(JitKernel, vrelu) { TEST(JitKernel, vrelu) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) { for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -80,7 +76,7 @@ TEST(JitKernel, vrelu) { ...@@ -80,7 +76,7 @@ TEST(JitKernel, vrelu) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vrelu_ref(d, x_data, zref_data); refer::VRelu<float>(x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
...@@ -90,7 +86,7 @@ TEST(JitKernel, vrelu) { ...@@ -90,7 +86,7 @@ TEST(JitKernel, vrelu) {
vrelu_intri8(d, x_data, zref_data); vrelu_intri8(d, x_data, zref_data);
} }
auto si1 = GetCurrentUS(); auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat << " us";
} }
#endif #endif
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -98,22 +94,18 @@ TEST(JitKernel, vrelu) { ...@@ -98,22 +94,18 @@ TEST(JitKernel, vrelu) {
ker->Compute(x_data, ztgt_data, d); ker->Compute(x_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat VLOG(3) << "Vec size " << d
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vaddbias_ref(const int n, const float a, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] + a;
}
}
TEST(JitKernel, vaddbias) { TEST(JitKernel, vaddbias) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) { for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -126,7 +118,7 @@ TEST(JitKernel, vaddbias) { ...@@ -126,7 +118,7 @@ TEST(JitKernel, vaddbias) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vaddbias_ref(d, a, x_data, zref_data); refer::VAddBias<float>(&a, x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -135,20 +127,15 @@ TEST(JitKernel, vaddbias) { ...@@ -135,20 +127,15 @@ TEST(JitKernel, vaddbias) {
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat VLOG(3) << "Vec size " << d
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vexp_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
void vexp_mkl(const int n, const float* x, float* y) { void vexp_mkl(const int n, const float* x, float* y) {
paddle::platform::dynload::vsExp(n, x, y); paddle::platform::dynload::vsExp(n, x, y);
...@@ -157,6 +144,7 @@ void vexp_mkl(const int n, const float* x, float* y) { ...@@ -157,6 +144,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
TEST(JitKernel, vexp) { TEST(JitKernel, vexp) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) { for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -168,7 +156,7 @@ TEST(JitKernel, vexp) { ...@@ -168,7 +156,7 @@ TEST(JitKernel, vexp) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vexp_ref(d, x_data, zref_data); refer::VExp<float>(x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
...@@ -193,26 +181,14 @@ TEST(JitKernel, vexp) { ...@@ -193,26 +181,14 @@ TEST(JitKernel, vexp) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat;
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
inline float _sigmoid(float x) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (x < min) ? min : ((x > max) ? max : x);
return 1.f / (1.f + std::exp(-tmp));
}
void vsigmoid_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = _sigmoid(x[i]);
}
}
void vsigmoid_better( void vsigmoid_better(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp, const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp,
...@@ -231,6 +207,7 @@ void vsigmoid_better( ...@@ -231,6 +207,7 @@ void vsigmoid_better(
TEST(JitKernel, vsigmoid) { TEST(JitKernel, vsigmoid) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -249,7 +226,7 @@ TEST(JitKernel, vsigmoid) { ...@@ -249,7 +226,7 @@ TEST(JitKernel, vsigmoid) {
auto tmkle = GetCurrentUS(); auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vsigmoid_ref(d, x_data, zref_data); refer::VSigmoid<float>(x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -258,23 +235,17 @@ TEST(JitKernel, vsigmoid) { ...@@ -258,23 +235,17 @@ TEST(JitKernel, vsigmoid) {
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
VLOG(3) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
inline float _tanh(float x) { return 2.f * _sigmoid(2.f * x) - 1.f; }
void vtanh_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = _tanh(x[i]);
}
}
void vtanh_better( void vtanh_better(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VScalKernel<float>>& vscal, const paddle::operators::math::jitkernel::VScalKernel<float>>& vscal,
...@@ -294,6 +265,7 @@ void vtanh_better( ...@@ -294,6 +265,7 @@ void vtanh_better(
TEST(JitKernel, vtanh) { TEST(JitKernel, vtanh) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -316,7 +288,7 @@ TEST(JitKernel, vtanh) { ...@@ -316,7 +288,7 @@ TEST(JitKernel, vtanh) {
auto tmkle = GetCurrentUS(); auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vtanh_ref(d, x_data, zref_data); refer::VTanh<float>(x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -325,41 +297,17 @@ TEST(JitKernel, vtanh) { ...@@ -325,41 +297,17 @@ TEST(JitKernel, vtanh) {
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
VLOG(3) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void lstm_ctht_ref(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
vsigmoid_3d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VTanhKernel<float>>& vtanh_d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->Compute(gates, gates, d);
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
for (int k = 0; k < d; ++k) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[k] = ct_1[k] * f[k] + gates[k] * i[k];
// H_t = act_cell(C_t) * ogated
float tmp = ct[k] * 2;
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vexp_1->Compute(&tmp, &tmp, 1);
tmp = 2.f / (1.f + tmp) - 1.f;
ht[k] = tmp * o[k];
}
}
void lstm_ctht_better( void lstm_ctht_better(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>& const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
...@@ -384,6 +332,7 @@ void lstm_ctht_better( ...@@ -384,6 +332,7 @@ void lstm_ctht_better(
TEST(JitKernel, lstm) { TEST(JitKernel, lstm) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) { for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) {
int d4 = d * 4; int d4 = d * 4;
int d3 = d * 3; int d3 = d * 3;
...@@ -394,19 +343,17 @@ TEST(JitKernel, lstm) { ...@@ -394,19 +343,17 @@ TEST(JitKernel, lstm) {
RandomVec<float>(d, ct_1.data(), -2.f, 2.f); RandomVec<float>(d, ct_1.data(), -2.f, 2.f);
memcpy(xref.data(), x.data(), sizeof(float) * d4); memcpy(xref.data(), x.data(), sizeof(float) * d4);
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
const jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell, false);
const auto& ker = const auto& ker =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
const std::string&, const std::string&>( attr);
act_gate, act_cand, act_cell, d, false);
// below kernels are used to compute refer // below kernels are used to compute refer
const auto& vsigmoid_3d = const auto& vsigmoid_3d =
jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>( jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(
d3); d3);
const auto& vtanh_d = const auto& vtanh_d =
jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d); jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
const auto& vexp_1 =
jit::KernelPool::Instance().template Get<jit::VExpKernel<float>>(1);
const auto& vmul_d = const auto& vmul_d =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d); jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d);
const auto& vadd_d = const auto& vadd_d =
...@@ -420,9 +367,17 @@ TEST(JitKernel, lstm) { ...@@ -420,9 +367,17 @@ TEST(JitKernel, lstm) {
float* ct_ref_data = ct_ref.data(); float* ct_ref_data = ct_ref.data();
float* ht_ref_data = ht_ref.data(); float* ht_ref_data = ht_ref.data();
// compute once to check correctness // compute once to check correctness
lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data, jit::lstm_t step;
ct_ref_data, ht_ref_data); step.gates = xref_data;
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); step.ct_1 = ct_1_data;
step.ct = ct_ref_data;
step.ht = ht_ref_data;
refer::LSTMCtHt<float>(&step, &attr);
step.gates = x_data;
step.ct = ct_tgt_data;
step.ht = ht_tgt_data;
ker->ComputeCtHt(&step, &attr);
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3); EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3);
EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3); EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3);
...@@ -436,31 +391,21 @@ TEST(JitKernel, lstm) { ...@@ -436,31 +391,21 @@ TEST(JitKernel, lstm) {
auto tmkle = GetCurrentUS(); auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data, refer::LSTMCtHt<float>(&step, &attr);
ct_ref_data, ht_ref_data);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); ker->ComputeCtHt(&step, &attr);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat VLOG(3) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit) takes: " << (tmkle - tmkls) / repeat << " us, better(jit) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
} }
} }
void vscal_ref(const int n, const float a, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = a * x[i];
}
}
void vscal_inp_ref(const int n, const float a, float* x) {
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vscal_intri8(const int n, const float a, const float* x, float* y) { void vscal_intri8(const int n, const float a, const float* x, float* y) {
__m256 tmp; __m256 tmp;
...@@ -486,6 +431,7 @@ void vscal_inp_mkl(const int n, const float a, float* x) { ...@@ -486,6 +431,7 @@ void vscal_inp_mkl(const int n, const float a, float* x) {
TEST(JitKernel, vscal) { TEST(JitKernel, vscal) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) { for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -500,12 +446,12 @@ TEST(JitKernel, vscal) { ...@@ -500,12 +446,12 @@ TEST(JitKernel, vscal) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vscal_ref(d, a, x_data, zref_data); refer::VScal<float>(&a, x_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto trefs1 = GetCurrentUS(); auto trefs1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vscal_inp_ref(d, a, y_data); refer::VScal<float>(&a, y_data, y_data, d);
} }
auto trefe1 = GetCurrentUS(); auto trefe1 = GetCurrentUS();
...@@ -530,7 +476,7 @@ TEST(JitKernel, vscal) { ...@@ -530,7 +476,7 @@ TEST(JitKernel, vscal) {
} }
auto si3 = GetCurrentUS(); auto si3 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat
<< " us, inplace: " << (si3 - si2) / repeat; << " us, inplace: " << (si3 - si2) / repeat << " us";
} }
#endif #endif
...@@ -552,19 +498,14 @@ TEST(JitKernel, vscal) { ...@@ -552,19 +498,14 @@ TEST(JitKernel, vscal) {
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat << "tgt takes: " << (ttgte - ttgts) / repeat
<< "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat; << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat
<< " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vmul_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vmul_intri8(const int n, const float* x, const float* y, float* z) { void vmul_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy; __m256 tmpx, tmpy;
...@@ -583,6 +524,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) { ...@@ -583,6 +524,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) {
TEST(JitKernel, vmul) { TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) { for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -596,7 +538,7 @@ TEST(JitKernel, vmul) { ...@@ -596,7 +538,7 @@ TEST(JitKernel, vmul) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vmul_ref(d, x_data, y_data, zref_data); refer::VMul<float>(x_data, y_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
...@@ -631,19 +573,13 @@ TEST(JitKernel, vmul) { ...@@ -631,19 +573,13 @@ TEST(JitKernel, vmul) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vadd_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vadd_intri8(const int n, const float* x, const float* y, float* z) { void vadd_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy; __m256 tmpx, tmpy;
...@@ -662,6 +598,7 @@ void vadd_mkl(const int n, const float* x, const float* y, float* z) { ...@@ -662,6 +598,7 @@ void vadd_mkl(const int n, const float* x, const float* y, float* z) {
TEST(JitKernel, vadd) { TEST(JitKernel, vadd) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) { for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -675,7 +612,7 @@ TEST(JitKernel, vadd) { ...@@ -675,7 +612,7 @@ TEST(JitKernel, vadd) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vadd_ref(d, x_data, y_data, zref_data); refer::VAdd<float>(x_data, y_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
...@@ -710,19 +647,13 @@ TEST(JitKernel, vadd) { ...@@ -710,19 +647,13 @@ TEST(JitKernel, vadd) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
} }
} }
void vaddrelu_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
void vaddrelu_better( void vaddrelu_better(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd, const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
...@@ -735,6 +666,7 @@ void vaddrelu_better( ...@@ -735,6 +666,7 @@ void vaddrelu_better(
TEST(JitKernel, vaddrelu) { TEST(JitKernel, vaddrelu) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) { for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -752,7 +684,7 @@ TEST(JitKernel, vaddrelu) { ...@@ -752,7 +684,7 @@ TEST(JitKernel, vaddrelu) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vaddrelu_ref(d, x_data, y_data, zref_data); refer::VAddRelu<float>(x_data, y_data, zref_data, d);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS(); auto tmkls = GetCurrentUS();
...@@ -765,9 +697,10 @@ TEST(JitKernel, vaddrelu) { ...@@ -765,9 +697,10 @@ TEST(JitKernel, vaddrelu) {
ker->Compute(x_data, y_data, ztgt_data, d); ker->Compute(x_data, y_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat VLOG(3) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat
<< " us, better takes: " << (tmkle - tmkls) / repeat << " us, " << " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -778,21 +711,23 @@ TEST(JitKernel, pool) { ...@@ -778,21 +711,23 @@ TEST(JitKernel, pool) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
const int frame_size = 4; const int frame_size = 4;
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
// empty call it to avoid unknown flag 'use_pinned_memory' on Mac
paddle::platform::jit::MayIUse(paddle::platform::jit::avx);
const auto& plstm1 = const auto& plstm1 =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
const std::string&, const std::string&>(
act_gate, act_cand, act_cell, frame_size, false);
const auto& plstm2 = const auto& plstm2 =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
const std::string&, const std::string&>( EXPECT_EQ(plstm1, plstm2);
act_gate, act_cand, act_cell, frame_size, false);
const auto& peephole = const auto& peephole =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
const std::string&, const std::string&>( jit::lstm_attr_t(frame_size, act_gate, act_cand, act_cell, true));
act_gate, act_cand, act_cell, frame_size, true);
EXPECT_TRUE(plstm1 != peephole); EXPECT_TRUE(plstm1 != peephole);
const auto& pvmul_f = const auto& pvmul_f =
......
...@@ -671,6 +671,55 @@ EOF ...@@ -671,6 +671,55 @@ EOF
${DOCKERFILE_CUBLAS_DSO} ${DOCKERFILE_CUBLAS_DSO}
${DOCKERFILE_GPU_ENV} ${DOCKERFILE_GPU_ENV}
ENV NCCL_LAUNCH_MODE PARALLEL ENV NCCL_LAUNCH_MODE PARALLEL
EOF
elif [ "$1" == "cp36-cp36m" ]; then
cat >> ${PADDLE_ROOT}/build/Dockerfile <<EOF
ADD python/dist/*.whl /
# run paddle version to install python packages first
RUN apt-get update && ${NCCL_DEPS}
RUN apt-get install -y make build-essential libssl-dev zlib1g-dev libbz2-dev \
libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \
xz-utils tk-dev libffi-dev liblzma-dev
RUN mkdir -p /root/python_build/ && wget -q https://www.sqlite.org/2018/sqlite-autoconf-3250300.tar.gz && \
tar -zxf sqlite-autoconf-3250300.tar.gz && cd sqlite-autoconf-3250300 && \
./configure -prefix=/usr/local && make -j8 && make install && cd ../ && rm sqlite-autoconf-3250300.tar.gz && \
wget -q https://www.python.org/ftp/python/3.6.0/Python-3.6.0.tgz && \
tar -xzf Python-3.6.0.tgz && cd Python-3.6.0 && \
CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null && \
make -j8 > /dev/null && make altinstall > /dev/null
RUN apt-get install -y libgtk2.0-dev dmidecode python3-tk && \
pip3.6 install opencv-python && pip3.6 install /*.whl; apt-get install -f -y && \
apt-get clean -y && \
rm -f /*.whl && \
${PADDLE_VERSION} && \
ldconfig
${DOCKERFILE_CUDNN_DSO}
${DOCKERFILE_CUBLAS_DSO}
${DOCKERFILE_GPU_ENV}
ENV NCCL_LAUNCH_MODE PARALLEL
EOF
elif [ "$1" == "cp37-cp37m" ]; then
cat >> ${PADDLE_ROOT}/build/Dockerfile <<EOF
ADD python/dist/*.whl /
# run paddle version to install python packages first
RUN apt-get update && ${NCCL_DEPS}
RUN apt-get install -y make build-essential libssl-dev zlib1g-dev libbz2-dev \
libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \
xz-utils tk-dev libffi-dev liblzma-dev
RUN wget -q https://www.python.org/ftp/python/3.7.0/Python-3.7.0.tgz && \
tar -xzf Python-3.7.0.tgz && cd Python-3.7.0 && \
CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null && \
make -j8 > /dev/null && make altinstall > /dev/null
RUN apt-get install -y libgtk2.0-dev dmidecode python3-tk && \
pip3.7 install opencv-python && pip3.7 install /*.whl; apt-get install -f -y && \
apt-get clean -y && \
rm -f /*.whl && \
${PADDLE_VERSION} && \
ldconfig
${DOCKERFILE_CUDNN_DSO}
${DOCKERFILE_CUBLAS_DSO}
${DOCKERFILE_GPU_ENV}
ENV NCCL_LAUNCH_MODE PARALLEL
EOF EOF
else else
cat >> ${PADDLE_ROOT}/build/Dockerfile <<EOF cat >> ${PADDLE_ROOT}/build/Dockerfile <<EOF
......
...@@ -93,13 +93,13 @@ if(WITH_DISTRIBUTE) ...@@ -93,13 +93,13 @@ if(WITH_DISTRIBUTE)
if(NOT APPLE) if(NOT APPLE)
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext) # FIXME(typhoonzero): add these tests back
set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000) # py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
# FIXME(typhoonzero): add this back # set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
#py_test_modules(test_dist_transformer MODULES test_dist_transformer) # py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) # set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
# TODO(typhoonzero): make dist test parallel when fix port management issue # TODO(typhoonzero): make dist test parallel when fix port management issue
set_tests_properties(test_dist_mnist test_dist_word2vec test_dist_se_resnext test_dist_ctr test_dist_simnet_bow test_dist_save_load test_dist_text_classification test_dist_mnist_batch_merge PROPERTIES RUN_SERIAL TRUE) set_tests_properties(test_dist_mnist test_dist_word2vec test_dist_ctr test_dist_simnet_bow test_dist_save_load test_dist_text_classification test_dist_mnist_batch_merge PROPERTIES RUN_SERIAL TRUE)
endif(NOT APPLE) endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif() endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册