提交 f9138608 编写于 作者: T tensor-tang

jitkernel lstm refer support peephole

test=develop
上级 2f9b5f23
...@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D4 = wh_dims[1] const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \ /* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \ const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \ /* for peephole only*/ \
T* checked_cell_data = nullptr; \ T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \ auto place = ctx.GetPlace(); \
if (use_peepholes) { \ if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} \ } \
const auto& ker = \ const math::jitkernel::lstm_attr_t attr( \
math::jitkernel::KernelPool::Instance() \ D, ctx.Attr<std::string>("gate_activation"), \
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \ ctx.Attr<std::string>("candidate_activation"), \
const std::string&, const std::string&>( \ ctx.Attr<std::string>("cell_activation"), use_peepholes); \
ctx.Attr<std::string>("gate_activation"), \ math::jitkernel::lstm_t one_step; \
ctx.Attr<std::string>("candidate_activation"), \ one_step.wp = wp_data; \
ctx.Attr<std::string>("cell_activation"), D, use_peepholes) one_step.checked = checked_cell_data; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, \
const math::jitkernel::lstm_attr_t&>(attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
...@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_h_data = h0_data + bid * D; prev_h_data = h0_data + bid * D;
prev_c_data = c0_data + bid * D; prev_c_data = c0_data + bid * D;
} else { } else {
ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data); one_step.gates = xx_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeC1H1(&one_step, &attr);
tstart = 1; tstart = 1;
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
...@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
} }
for (int step = tstart; step < seq_len; ++step) { for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data); GEMM_WH_ADDON(1, prev_h_data, xx_data);
ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data,
checked_cell_data); one_step.gates = xx_data;
one_step.ct_1 = prev_c_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeCtHt(&one_step, &attr);
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
prev_c_data = c_out_data; prev_c_data = c_out_data;
...@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_h_out_data = batched_h_out_data; T* cur_h_out_data = batched_h_out_data;
T* cur_c_out_data = batched_c_out_data; T* cur_c_out_data = batched_c_out_data;
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data); one_step.gates = cur_in_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeC1H1(&one_step, &attr);
cur_in_data += D4; cur_in_data += D4;
cur_c_out_data += D; cur_c_out_data += D;
cur_h_out_data += D; cur_h_out_data += D;
...@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_c_out_data = batched_c_out_data; T* cur_c_out_data = batched_c_out_data;
T* cur_h_out_data = batched_h_out_data; T* cur_h_out_data = batched_h_out_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, one_step.gates = cur_in_data;
cur_h_out_data, wp_data, checked_cell_data); one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeCtHt(&one_step, &attr);
// move one batch // move one batch
cur_in_data += D4; cur_in_data += D4;
cur_prev_c_data += D; cur_prev_c_data += D;
......
...@@ -233,7 +233,7 @@ void LSTMJitCode::generate() { ...@@ -233,7 +233,7 @@ void LSTMJitCode::generate() {
vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]);
act<ymm_t>(ymm_i, ymm_src, act_gate_); act<ymm_t>(ymm_i, ymm_src, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i); vmulps(ymm_c, ymm_c, ymm_i);
if (first_) { if (!compute_c1h1_) {
// f // f
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]);
act<ymm_t>(ymm_f, ymm_src, act_gate_); act<ymm_t>(ymm_f, ymm_src, act_gate_);
...@@ -242,8 +242,8 @@ void LSTMJitCode::generate() { ...@@ -242,8 +242,8 @@ void LSTMJitCode::generate() {
vaddps(ymm_f, ymm_f, ymm_c); vaddps(ymm_f, ymm_f, ymm_c);
} }
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
ymm_t ymm_ct = first_ ? ymm_c : ymm_f; ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_o = first_ ? ymm_f : ymm_c; ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
ymm_t ymm_tmp = ymm_i; ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_); act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]);
......
...@@ -319,6 +319,12 @@ class LSTMJitCode : public VActJitCode { ...@@ -319,6 +319,12 @@ class LSTMJitCode : public VActJitCode {
public: public:
const char* name() const override { const char* name() const override {
std::string base = "LSTMJitCode"; std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
}
if (compute_c1h1_) {
base += "_C1H1";
}
auto AddTypeStr = [&](operand_type type) { auto AddTypeStr = [&](operand_type type) {
switch (type) { switch (type) {
case operand_type::relu: case operand_type::relu:
...@@ -340,30 +346,42 @@ class LSTMJitCode : public VActJitCode { ...@@ -340,30 +346,42 @@ class LSTMJitCode : public VActJitCode {
break; break;
} }
}; };
if (first_) {
base += "_C1H1";
}
AddTypeStr(act_gate_); AddTypeStr(act_gate_);
AddTypeStr(act_cand_); AddTypeStr(act_cand_);
AddTypeStr(act_cell_); AddTypeStr(act_cell_);
return base.c_str(); return base.c_str();
} }
explicit LSTMJitCode(int d, bool first, operand_type act_gate, explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
operand_type act_cand, operand_type act_cell,
size_t code_size = 256 * 1024, void* code_ptr = nullptr) size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(d, act_gate, code_size, code_ptr), : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
num_(d), code_ptr),
first_(first), compute_c1h1_(compute_c1h1) {
act_gate_(act_gate), auto typeExchange = [](const std::string& type) -> gen::operand_type {
act_cand_(act_cand), if (type == "sigmoid") {
act_cell_(act_cell) {} return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
use_peephole_ = attr.use_peephole;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
}
static bool init(int d); static bool init(int d);
void generate() override; void generate() override;
protected: protected:
int num_; int num_;
bool first_; bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_; operand_type act_gate_;
operand_type act_cand_; operand_type act_cand_;
operand_type act_cell_; operand_type act_cell_;
......
...@@ -122,18 +122,9 @@ class VTanhKernel : public VActKernel<T> {}; ...@@ -122,18 +122,9 @@ class VTanhKernel : public VActKernel<T> {};
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class LSTMKernel : public Kernel {
public: public:
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht, void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
/* below only used in peephole*/ // compute c1 and h1 without c0 or h0
const T *wp_data = nullptr, void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
T *checked = nullptr) const = 0;
virtual void ComputeC1H1(T *gates, T *ct, T *ht,
/* below only used in peephole*/
const T *wp_data = nullptr) const = 0;
// void (*ComputeCtHt)(lstm_t *);
// // compute c1 and h1 without c0 or h0
// void (*ComputeC1H1)(lstm_t *);
}; };
template <typename T> template <typename T>
......
...@@ -33,18 +33,24 @@ typedef struct { ...@@ -33,18 +33,24 @@ typedef struct {
const void* ct_1; const void* ct_1;
void* ct; void* ct;
void* ht; void* ht;
/* below only used in peephole*/ /* weight_peephole and checked data are only used in peephole*/
const void* wp_data{nullptr}; const void* wp{nullptr};
void* checked{nullptr}; void* checked{nullptr};
} lstm_t; } lstm_t;
typedef struct lstm_attr_s { typedef struct lstm_attr_s {
bool use_peephole;
int d; int d;
std::string act_gate, act_cand, act_cell; std::string act_gate, act_cand, act_cell;
lstm_attr_s() = default; lstm_attr_s() = default;
lstm_attr_s(int _d, const std::string& _act_gate, lstm_attr_s(int _d, const std::string& _act_gate,
const std::string& _act_cand, const std::string& _act_cell) const std::string& _act_cand, const std::string& _act_cell,
: d(_d), act_gate(_act_gate), act_cand(_act_cand), act_cell(_act_cell) {} bool _use_peephole = false)
: use_peephole(_use_peephole),
d(_d),
act_gate(_act_gate),
act_cand(_act_cand),
act_cell(_act_cell) {}
} lstm_attr_t; } lstm_attr_t;
} // namespace jitkernel } // namespace jitkernel
......
...@@ -82,10 +82,10 @@ namespace jitkernel { ...@@ -82,10 +82,10 @@ namespace jitkernel {
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \ #define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \ marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \ marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE, \ REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL); \ macro_find_key, macro_impl); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE, \ REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL) macro_find_key, macro_impl)
#define REGISTER_JITKERNEL(ker_key, ker_class) \ #define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \ REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
......
...@@ -117,11 +117,13 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT ...@@ -117,11 +117,13 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
} }
template <typename T> template <typename T>
void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) { void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates); T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1); const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct); T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht); T* ht = reinterpret_cast<T*>(step->ht);
const T* wp = reinterpret_cast<const T*>(step->wp);
T* checked = reinterpret_cast<T*>(step->checked);
auto act_gate = getActFunc<T>(attr->act_gate); auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand); auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell); auto act_cell = getActFunc<T>(attr->act_cell);
...@@ -129,23 +131,36 @@ void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) { ...@@ -129,23 +131,36 @@ void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) {
int d2 = d * 2; int d2 = d * 2;
int d3 = d * 3; int d3 = d * 3;
// gates: W_ch, W_ih, W_fh, W_oh // gates: W_ch, W_ih, W_fh, W_oh
act_gate(gates + d, gates + d, d3); if (attr->use_peephole) {
VMul(wp, ct_1, checked, d);
VMul(wp + d, ct_1, checked + d, d);
VAdd(checked, gates + d, gates + d, d2);
act_gate(gates + d, gates + d, d2);
} else {
act_gate(gates + d, gates + d, d3);
}
/* C_t = C_t-1 * fgated + cand_gated * igated */ // C_t = C_t-1 * fgated + cand_gated * igated
act_cand(gates, gates, d); act_cand(gates, gates, d);
VMul(gates, gates + d, gates + d, d); VMul(gates, gates + d, gates + d, d);
VMul(ct_1, gates + d2, gates + d2, d); VMul(ct_1, gates + d2, gates + d2, d);
VAdd(gates + d, gates + d2, ct, d); VAdd(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */ if (attr->use_peephole) {
// get ogated
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
act_gate(gates + d3, gates + d3, d);
}
// H_t = act_cell(C_t) * ogated
act_cell(ct, gates + d2, d); act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d); VMul(gates + d2, gates + d3, ht, d);
} }
// compute c1 and h1 without c0 or h0
template <typename T> template <typename T>
void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) { void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates); T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct); T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht); T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate); auto act_gate = getActFunc<T>(attr->act_gate);
...@@ -158,10 +173,16 @@ void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) { ...@@ -158,10 +173,16 @@ void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) {
act_gate(gates + d, gates + d, d); act_gate(gates + d, gates + d, d);
act_cand(gates, gates, d); act_cand(gates, gates, d);
VMul(gates, gates + d, ct, d); VMul(gates, gates + d, ct, d);
if (attr->use_peephole) {
// get outgated, put W_oc * C_t on igated
const T* wp = reinterpret_cast<const T*>(step->wp);
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
}
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_gate(gates + d3, gates + d3, d); act_gate(gates + d3, gates + d3, d);
act_cell(ct, gates + d2, d); act_cell(ct, gates + d2, d);
Vmul(gates + d2, gates + d3, ht, d); VMul(gates + d2, gates + d3, ht, d);
} }
} // namespace refer } // namespace refer
......
...@@ -15,9 +15,14 @@ limitations under the License. */ ...@@ -15,9 +15,14 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef __AVX__ #ifdef __AVX__
#include <immintrin.h> #include <immintrin.h>
#endif #endif
...@@ -154,211 +159,136 @@ static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) { ...@@ -154,211 +159,136 @@ static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
#endif #endif
/* LSTM JitKernel */ /* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T>
class LSTMKernelImpl : public LSTMKernel<T> { class LSTMKernelImpl : public LSTMKernel<T> {
public: public:
explicit LSTMKernelImpl(const std::string& act_gate, static inline std::string name(const lstm_attr_t& attr) {
const std::string& act_cand, PADDLE_THROW("DType should be either float or double");
const std::string& act_cell, int d)
: LSTMKernel<T>() {
d_ = d;
d2_ = d * 2;
d3_ = d * 3;
act_gate_d3_ = GetActKernel<T>(act_gate, d3_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_cand_d_ = GetActKernel<T>(act_cand, d);
act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
} }
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
this->ComputeC1H1 =
jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
return;
}
#endif
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, this->ComputeCtHt = refer::LSTMCtHt<T>;
T* checked) const override { this->ComputeC1H1 = refer::LSTMC1H1<T>;
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_->Compute(gates + d_, gates + d_, d3_);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
#ifdef PADDLE_WITH_XBYAK
private: private:
int d_, d2_, d3_; std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
std::shared_ptr<const VActKernel<T>> act_gate_d3_, act_gate_d_, act_cand_d_,
act_cell_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
std::shared_ptr<const VAddKernel<T>> vadd_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_cand_, avx_act_cell_;
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #ifdef PADDLE_WITH_XBYAK
template <> \ template <>
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \ bool LSTMKernelImpl<float>::useJIT(int d) {
const std::string& act_gate, const std::string& act_cand, \ return false; // not ready yet gen::LSTMJitCode::init(d);
const std::string& act_cell, int d) \ }
: LSTMKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_cand_ = GetAVXAct<isa>(act_cand); \
avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */ \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
// TODO(TJ): optimize keq16
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif #endif
/* Peephole JitKernel */ /* Peephole JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T>
class PeepholeKernelImpl : public LSTMKernel<T> { class PeepholeKernelImpl : public LSTMKernel<T> {
public: public:
explicit PeepholeKernelImpl(const std::string& act_gate, static inline std::string name(const lstm_attr_t& attr) {
const std::string& act_cand, PADDLE_THROW("DType should be either float or double");
const std::string& act_cell, int d)
: LSTMKernel<T>() {
d_ = d;
d2_ = d * 2;
d3_ = d * 3;
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_cand_d_ = GetActKernel<T>(act_cand, d);
act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
vadd_d2_ = KernelPool::Instance().template Get<VAddKernel<T>>(d2_);
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
} }
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
this->ComputeC1H1 =
jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
return;
}
#endif
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, this->ComputeCtHt = refer::LSTMCtHt<T>;
T* checked) const override { this->ComputeC1H1 = refer::LSTMC1H1<T>;
/* get fgated and igated*/
vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { #ifdef PADDLE_WITH_XBYAK
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
private: private:
int d_, d2_, d3_; std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_cand_d_, #endif
act_cell_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
std::shared_ptr<const VAddKernel<T>> vadd_d_, vadd_d2_;
}; };
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ #ifdef PADDLE_WITH_XBYAK
template <> \ template <>
std::shared_ptr<const LSTMKernel<ker_dtype>> \ bool PeepholeKernelImpl<float>::useJIT(int d) {
KernelPool::Get<LSTMKernel<ker_dtype>, const std::string&, \ return false; // peephole jitcode not ready yet
const std::string&, const std::string&, int, bool>( \ }
const std::string& act_gate, const std::string& act_cand, \ #endif
const std::string& act_cell, int d, bool use_peephole)
#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \ template <> \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \ std::string ker_class##Impl<float>::name(const lstm_attr_t& attr) { \
(use_peephole ? "p" : "n") std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand + attr.act_cell + \
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \ (attr.use_peephole ? "p" : "n")); \
if (use_peephole) { \ if (useJIT(attr.d)) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \ /* only jit code need record d*/ \
std::make_shared<PeepholeKernelImpl<dtype, isa, k>>( \ return key + "jit" + std::to_string(attr.d); \
act_gate, act_cand, act_cell, d)); \ } else if (useMKL(attr.d)) { \
} else { \ return key + "mkl"; \
p = std::dynamic_pointer_cast<ker<dtype>>( \ } else { \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_cand, \ return key + "any"; \
act_cell, d)); \ } \
} \
template <> \
std::string ker_class##Impl<double>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/ \
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} }
REGISTER_JITKERNEL_ARGS_DEPRECATED(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, #define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL); template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const lstm_attr_t&>( \
const lstm_attr_t& attr)
#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_LSTM_IMPL(ker, dtype) \
if (attr.use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype>>(attr)); \
} else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr)); \
}
#undef INTRI8_FLOAT REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM,
#undef JITKERNEL_DECLARE_LSTM JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM,
#undef JITKERNEL_KEY_LSTM JITKERNEL_LSTM_IMPL);
#undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */ /* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
......
...@@ -341,11 +341,11 @@ TEST(JitKernel, lstm) { ...@@ -341,11 +341,11 @@ TEST(JitKernel, lstm) {
RandomVec<float>(d, ct_1.data(), -2.f, 2.f); RandomVec<float>(d, ct_1.data(), -2.f, 2.f);
memcpy(xref.data(), x.data(), sizeof(float) * d4); memcpy(xref.data(), x.data(), sizeof(float) * d4);
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
const jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell, false);
const auto& ker = const auto& ker =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
const std::string&, const std::string&>( attr);
act_gate, act_cand, act_cell, d, false);
// below kernels are used to compute refer // below kernels are used to compute refer
const auto& vsigmoid_3d = const auto& vsigmoid_3d =
jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>( jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(
...@@ -366,14 +366,16 @@ TEST(JitKernel, lstm) { ...@@ -366,14 +366,16 @@ TEST(JitKernel, lstm) {
float* ht_ref_data = ht_ref.data(); float* ht_ref_data = ht_ref.data();
// compute once to check correctness // compute once to check correctness
jit::lstm_t step; jit::lstm_t step;
jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell);
step.gates = xref_data; step.gates = xref_data;
step.ct_1 = ct_1_data; step.ct_1 = ct_1_data;
step.ct = ct_ref_data; step.ct = ct_ref_data;
step.ht = ht_ref_data; step.ht = ht_ref_data;
refer::LSTMCtHt<float>(&step, &attr); refer::LSTMCtHt<float>(&step, &attr);
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); step.gates = x_data;
step.ct = ct_tgt_data;
step.ht = ht_tgt_data;
ker->ComputeCtHt(&step, &attr);
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3); EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3);
EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3); EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3);
...@@ -392,7 +394,7 @@ TEST(JitKernel, lstm) { ...@@ -392,7 +394,7 @@ TEST(JitKernel, lstm) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); ker->ComputeCtHt(&step, &attr);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
...@@ -710,21 +712,21 @@ TEST(JitKernel, pool) { ...@@ -710,21 +712,21 @@ TEST(JitKernel, pool) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
const int frame_size = 4; const int frame_size = 4;
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
const auto& plstm1 = const auto& plstm1 =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
const std::string&, const std::string&>(
act_gate, act_cand, act_cell, frame_size, false);
const auto& plstm2 = const auto& plstm2 =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
const std::string&, const std::string&>( EXPECT_EQ(plstm1, plstm2);
act_gate, act_cand, act_cell, frame_size, false);
const auto& peephole = const auto& peephole =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&, .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
const std::string&, const std::string&>( jit::lstm_attr_t(frame_size, act_gate, act_cand, act_cell, true));
act_gate, act_cand, act_cell, frame_size, true);
EXPECT_TRUE(plstm1 != peephole); EXPECT_TRUE(plstm1 != peephole);
const auto& pvmul_f = const auto& pvmul_f =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册