提交 f9138608 编写于 作者: T tensor-tang

jitkernel: LSTM refer (reference) implementation supports peephole

test=develop
上级 2f9b5f23
......@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \
const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \
const std::string&, const std::string&>( \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
// Declares the locals shared by the LSTM compute paths. The macro is expanded
// inside a function body and relies on `x`, `wx`, `wh`, `bias`, `ctx`,
// `use_peepholes`, `D` and `D4` already being in scope there. It introduces:
//   - x_data / wx_data / wh_data: data pointers read from x, wx and wh.
//   - wp_data: the bias data pointer advanced by D4 (the peephole weights).
//   - checked_cell_data: nullptr unless use_peepholes is set, in which case it
//     is the mutable buffer of the "CheckedCell" output.
//   - attr: jitkernel LSTM attributes built from D, the three activation
//     attributes and use_peepholes.
//   - one_step: the per-step argument struct, pre-filled with wp and checked;
//     the remaining fields are left for the use site to fill in.
//   - ker: the LSTMKernel<T> obtained from KernelPool for `attr`.
// The expansion deliberately ends without a semicolon; the use site adds it.
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal (peephole) weight, located D4 elements into the bias data */ \
const T* wp_data = bias->data<T>() + D4; \
/* used by the peephole path only; stays nullptr otherwise */ \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const math::jitkernel::lstm_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), use_peepholes); \
math::jitkernel::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, \
const math::jitkernel::lstm_attr_t&>(attr)
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
......@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_h_data = h0_data + bid * D;
prev_c_data = c0_data + bid * D;
} else {
ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data);
one_step.gates = xx_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeC1H1(&one_step, &attr);
tstart = 1;
// move one step
prev_h_data = h_out_data;
......@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data);
ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data,
checked_cell_data);
one_step.gates = xx_data;
one_step.ct_1 = prev_c_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeCtHt(&one_step, &attr);
// move one step
prev_h_data = h_out_data;
prev_c_data = c_out_data;
......@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_h_out_data = batched_h_out_data;
T* cur_c_out_data = batched_c_out_data;
for (int i = 0; i < max_bs; ++i) {
ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data);
one_step.gates = cur_in_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeC1H1(&one_step, &attr);
cur_in_data += D4;
cur_c_out_data += D;
cur_h_out_data += D;
......@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_c_out_data = batched_c_out_data;
T* cur_h_out_data = batched_h_out_data;
for (int i = 0; i < cur_bs; ++i) {
ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
cur_h_out_data, wp_data, checked_cell_data);
one_step.gates = cur_in_data;
one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeCtHt(&one_step, &attr);
// move one batch
cur_in_data += D4;
cur_prev_c_data += D;
......
......@@ -233,7 +233,7 @@ void LSTMJitCode::generate() {
vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]);
act<ymm_t>(ymm_i, ymm_src, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (first_) {
if (!compute_c1h1_) {
// f
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]);
act<ymm_t>(ymm_f, ymm_src, act_gate_);
......@@ -242,8 +242,8 @@ void LSTMJitCode::generate() {
vaddps(ymm_f, ymm_f, ymm_c);
}
/* H_t = act_cell(C_t) * ogated */
ymm_t ymm_ct = first_ ? ymm_c : ymm_f;
ymm_t ymm_o = first_ ? ymm_f : ymm_c;
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]);
......
......@@ -319,6 +319,12 @@ class LSTMJitCode : public VActJitCode {
public:
const char* name() const override {
std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
}
if (compute_c1h1_) {
base += "_C1H1";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
......@@ -340,30 +346,42 @@ class LSTMJitCode : public VActJitCode {
break;
}
};
if (first_) {
base += "_C1H1";
}
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
explicit LSTMJitCode(int d, bool first, operand_type act_gate,
operand_type act_cand, operand_type act_cell,
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(d, act_gate, code_size, code_ptr),
num_(d),
first_(first),
act_gate_(act_gate),
act_cand_(act_cand),
act_cell_(act_cell) {}
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
compute_c1h1_(compute_c1h1) {
// Maps an activation name to the matching gen::operand_type.
// Recognized names: "sigmoid", "relu", "tanh", and "identity" or the empty
// string (both map to identity).
// NOTE: any unrecognized name currently also falls through to identity
// without reporting an error.
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // TODO: throw an error here for unknown activation names
// Fallback for unknown names until the error above is implemented.
return operand_type::identity;
};
num_ = attr.d;
use_peephole_ = attr.use_peephole;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
}
static bool init(int d);
void generate() override;
protected:
int num_;
bool first_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
......
......@@ -122,18 +122,9 @@ class VTanhKernel : public VActKernel<T> {};
template <typename T>
class LSTMKernel : public Kernel {
public:
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht,
/* below only used in peephole*/
const T *wp_data = nullptr,
T *checked = nullptr) const = 0;
virtual void ComputeC1H1(T *gates, T *ct, T *ht,
/* below only used in peephole*/
const T *wp_data = nullptr) const = 0;
// void (*ComputeCtHt)(lstm_t *);
// // compute c1 and h1 without c0 or h0
// void (*ComputeC1H1)(lstm_t *);
void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
// compute c1 and h1 without c0 or h0
void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
};
template <typename T>
......
......@@ -33,18 +33,24 @@ typedef struct {
const void* ct_1;
void* ct;
void* ht;
/* below only used in peephole*/
const void* wp_data{nullptr};
/* weight_peephole and checked data are only used in peephole*/
const void* wp{nullptr};
void* checked{nullptr};
} lstm_t;
typedef struct lstm_attr_s {
bool use_peephole;
int d;
std::string act_gate, act_cand, act_cell;
lstm_attr_s() = default;
lstm_attr_s(int _d, const std::string& _act_gate,
const std::string& _act_cand, const std::string& _act_cell)
: d(_d), act_gate(_act_gate), act_cand(_act_cand), act_cell(_act_cell) {}
const std::string& _act_cand, const std::string& _act_cell,
bool _use_peephole = false)
: use_peephole(_use_peephole),
d(_d),
act_gate(_act_gate),
act_cand(_act_cand),
act_cell(_act_cell) {}
} lstm_attr_t;
} // namespace jitkernel
......
......@@ -82,10 +82,10 @@ namespace jitkernel {
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL)
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
macro_find_key, macro_impl); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
macro_find_key, macro_impl)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
......
......@@ -117,11 +117,13 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
}
template <typename T>
void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) {
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
const T* wp = reinterpret_cast<const T*>(step->wp);
T* checked = reinterpret_cast<T*>(step->checked);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
......@@ -129,23 +131,36 @@ void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) {
int d2 = d * 2;
int d3 = d * 3;
// gates: W_ch, W_ih, W_fh, W_oh
act_gate(gates + d, gates + d, d3);
if (attr->use_peephole) {
VMul(wp, ct_1, checked, d);
VMul(wp + d, ct_1, checked + d, d);
VAdd(checked, gates + d, gates + d, d2);
act_gate(gates + d, gates + d, d2);
} else {
act_gate(gates + d, gates + d, d3);
}
/* C_t = C_t-1 * fgated + cand_gated * igated */
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand(gates, gates, d);
VMul(gates, gates + d, gates + d, d);
VMul(ct_1, gates + d2, gates + d2, d);
VAdd(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */
if (attr->use_peephole) {
// get ogated
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
act_gate(gates + d3, gates + d3, d);
}
// H_t = act_cell(C_t) * ogated
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
// compute c1 and h1 without c0 or h0
template <typename T>
void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) {
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
......@@ -158,10 +173,16 @@ void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) {
act_gate(gates + d, gates + d, d);
act_cand(gates, gates, d);
VMul(gates, gates + d, ct, d);
if (attr->use_peephole) {
// get outgated, put W_oc * C_t on igated
const T* wp = reinterpret_cast<const T*>(step->wp);
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
}
/* H_t = act_cell(C_t) * ogated */
act_gate(gates + d3, gates + d3, d);
act_cell(ct, gates + d2, d);
Vmul(gates + d2, gates + d3, ht, d);
VMul(gates + d2, gates + d3, ht, d);
}
} // namespace refer
......
......@@ -15,9 +15,14 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
......@@ -154,211 +159,136 @@ static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
#endif
/* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
template <typename T>
class LSTMKernelImpl : public LSTMKernel<T> {
public:
explicit LSTMKernelImpl(const std::string& act_gate,
const std::string& act_cand,
const std::string& act_cell, int d)
: LSTMKernel<T>() {
d_ = d;
d2_ = d * 2;
d3_ = d * 3;
act_gate_d3_ = GetActKernel<T>(act_gate, d3_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_cand_d_ = GetActKernel<T>(act_cand, d);
act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
static inline std::string name(const lstm_attr_t& attr) {
PADDLE_THROW("DType should be either float or double");
}
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
this->ComputeC1H1 =
jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
return;
}
#endif
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
T* checked) const override {
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_->Compute(gates + d_, gates + d_, d3_);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
this->ComputeCtHt = refer::LSTMCtHt<T>;
this->ComputeC1H1 = refer::LSTMC1H1<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
int d_, d2_, d3_;
std::shared_ptr<const VActKernel<T>> act_gate_d3_, act_gate_d_, act_cand_d_,
act_cell_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
std::shared_ptr<const VAddKernel<T>> vadd_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_cand_, avx_act_cell_;
std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \
: LSTMKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_cand_ = GetAVXAct<isa>(act_cand); \
avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */ \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
// TODO(TJ): optimize keq16
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#ifdef PADDLE_WITH_XBYAK
// Float specialization of LSTMKernelImpl::useJIT.
// Currently always returns false regardless of d, so callers that branch on
// useJIT(d) never take the JIT path for this kernel.
template <>
bool LSTMKernelImpl<float>::useJIT(int d) {
return false; // JIT code not ready yet; presumably to become gen::LSTMJitCode::init(d)
}
#endif
/* Peephole JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
template <typename T>
class PeepholeKernelImpl : public LSTMKernel<T> {
public:
explicit PeepholeKernelImpl(const std::string& act_gate,
const std::string& act_cand,
const std::string& act_cell, int d)
: LSTMKernel<T>() {
d_ = d;
d2_ = d * 2;
d3_ = d * 3;
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_cand_d_ = GetActKernel<T>(act_cand, d);
act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
vadd_d2_ = KernelPool::Instance().template Get<VAddKernel<T>>(d2_);
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
static inline std::string name(const lstm_attr_t& attr) {
PADDLE_THROW("DType should be either float or double");
}
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
this->ComputeC1H1 =
jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
return;
}
#endif
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
T* checked) const override {
/* get fgated and igated*/
vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
this->ComputeCtHt = refer::LSTMCtHt<T>;
this->ComputeC1H1 = refer::LSTMC1H1<T>;
}
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
#ifdef PADDLE_WITH_XBYAK
private:
int d_, d2_, d3_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_cand_d_,
act_cell_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
std::shared_ptr<const VAddKernel<T>> vadd_d_, vadd_d2_;
std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
#endif
};
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const std::string&, \
const std::string&, const std::string&, int, bool>( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d, bool use_peephole)
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \
(use_peephole ? "p" : "n")
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \
if (use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype, isa, k>>( \
act_gate, act_cand, act_cell, d)); \
} else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_cand, \
act_cell, d)); \
#ifdef PADDLE_WITH_XBYAK
// Float specialization of PeepholeKernelImpl::useJIT.
// Always returns false regardless of d: no JIT code is used for the peephole
// kernel at this point.
template <>
bool PeepholeKernelImpl<float>::useJIT(int d) {
return false; // peephole JIT code is not ready yet
}
#endif
// Defines the float and double specializations of ker_class##Impl<...>::name,
// which build a string key from an lstm_attr_t.
//   float : #ker_key "f" + act_gate + act_cand + act_cell + ("p" if
//           use_peephole else "n"), followed by
//             "jit" + d   when useJIT(attr.d) is true,
//             "mkl"       when useMKL(attr.d) is true,
//             "any"       otherwise.
//   double: #ker_key "d" followed by "mkl" or "any" (no JIT variant).
// NOTE(review): unlike the float key, the double key does not include the
// activation names or the peephole flag, so all double attrs with the same
// useMKL result share one key - confirm this is intended.
#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand + attr.act_cell + \
(attr.use_peephole ? "p" : "n")); \
if (useJIT(attr.d)) { \
/* only the JIT code needs d recorded in the key */ \
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* JIT code does not support double yet */ \
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
REGISTER_JITKERNEL_ARGS_DEPRECATED(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
// Expands to the signature of the explicit specialization of KernelPool::Get
// for LSTMKernel<ker_dtype>, taking a single `const lstm_attr_t& attr` and
// returning std::shared_ptr<const LSTMKernel<ker_dtype>>.
// Only ker_dtype appears in the expansion; ker_class is not used here.
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const lstm_attr_t&>( \
const lstm_attr_t& attr)
// Declares a local `key` holding ker_class##Impl<ker_dtype>::name(attr).
// Expects `attr` to be in scope at the expansion site; the trailing semicolon
// is supplied by the user of the macro.
#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
// Assigns `p` a newly created kernel built from `attr`:
//   - PeepholeKernelImpl<dtype> when attr.use_peephole is true,
//   - ker##Impl<dtype> otherwise,
// in both cases converted to std::shared_ptr<ker<dtype>> via
// std::dynamic_pointer_cast. Expects `p` and `attr` to be in scope at the
// expansion site.
#define JITKERNEL_LSTM_IMPL(ker, dtype) \
if (attr.use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype>>(attr)); \
} else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr)); \
}
#undef INTRI8_FLOAT
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL
REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM,
JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM,
JITKERNEL_LSTM_IMPL);
/* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
......
......@@ -341,11 +341,11 @@ TEST(JitKernel, lstm) {
RandomVec<float>(d, ct_1.data(), -2.f, 2.f);
memcpy(xref.data(), x.data(), sizeof(float) * d4);
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
const jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell, false);
const auto& ker =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&,
const std::string&, const std::string&>(
act_gate, act_cand, act_cell, d, false);
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
attr);
// below kernels are used to compute refer
const auto& vsigmoid_3d =
jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(
......@@ -366,14 +366,16 @@ TEST(JitKernel, lstm) {
float* ht_ref_data = ht_ref.data();
// compute once to check correctness
jit::lstm_t step;
jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell);
step.gates = xref_data;
step.ct_1 = ct_1_data;
step.ct = ct_ref_data;
step.ht = ht_ref_data;
refer::LSTMCtHt<float>(&step, &attr);
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data);
step.gates = x_data;
step.ct = ct_tgt_data;
step.ht = ht_tgt_data;
ker->ComputeCtHt(&step, &attr);
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3);
EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3);
......@@ -392,7 +394,7 @@ TEST(JitKernel, lstm) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data);
ker->ComputeCtHt(&step, &attr);
}
auto ttgte = GetCurrentUS();
VLOG(30) << "Vec size " << d
......@@ -710,21 +712,21 @@ TEST(JitKernel, pool) {
namespace jit = paddle::operators::math::jitkernel;
const int frame_size = 4;
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
const auto& plstm1 =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&,
const std::string&, const std::string&>(
act_gate, act_cand, act_cell, frame_size, false);
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
const auto& plstm2 =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&,
const std::string&, const std::string&>(
act_gate, act_cand, act_cell, frame_size, false);
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
EXPECT_EQ(plstm1, plstm2);
const auto& peephole =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const std::string&,
const std::string&, const std::string&>(
act_gate, act_cand, act_cell, frame_size, true);
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
jit::lstm_attr_t(frame_size, act_gate, act_cand, act_cell, true));
EXPECT_TRUE(plstm1 != peephole);
const auto& pvmul_f =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册