提交 35620513 编写于 作者: T tensor-tang

add gru refer code and remove redundant avx code

test=develop
上级 f9138608
...@@ -192,11 +192,14 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -192,11 +192,14 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int M = x_dims[1]; \ const int M = x_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D2 = D * 2; \ const int D2 = D * 2; \
const auto& ker = math::jitkernel::KernelPool::Instance() \ const math::jitkernel::gru_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation")); \
math::jitkernel::gru_t one_step; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::GRUKernel<T>, \ .template Get<math::jitkernel::GRUKernel<T>, \
const std::string&, const std::string&>( \ const math::jitkernel::gru_attr_t&>(attr); \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation"), D); \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
...@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) { if (h0_data) {
prev_hidden_data = h0_data + bid * D; prev_hidden_data = h0_data + bid * D;
} else { } else {
ker->ComputeH1(xx_data, hidden_out_data); one_step.gates = xx_data;
one_step.ht = hidden_out_data;
ker->ComputeH1(&one_step, &attr);
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data, prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3); D3);
ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data); one_step.gates = xx_data;
one_step.ht_1 = prev_hidden_data;
one_step.ht = hidden_out_data;
ker->ComputeHtPart1(&one_step, &attr);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data); ker->ComputeHtPart2(&one_step, &attr);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
ker->ComputeH1(cur_in_data, cur_out_data); one_step.gates = cur_in_data;
one_step.ht = cur_out_data;
ker->ComputeH1(&one_step, &attr);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data, one_step.gates = cur_batched_data;
cur_out_data); one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data;
ker->ComputeHtPart1(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data, one_step.gates = cur_batched_data;
cur_out_data); one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data;
ker->ComputeHtPart2(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
......
...@@ -122,18 +122,18 @@ class VTanhKernel : public VActKernel<T> {}; ...@@ -122,18 +122,18 @@ class VTanhKernel : public VActKernel<T> {};
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class LSTMKernel : public Kernel {
public: public:
void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
// compute c1 and h1 without c0 or h0 // compute c1 and h1 without c0 or h0
void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *); void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
}; };
template <typename T> template <typename T>
class GRUKernel : public Kernel { class GRUKernel : public Kernel {
public: public:
// compute h1 without h0 // compute h1 without h0
virtual void ComputeH1(T *gates, T *ht) const = 0; void (*ComputeH1)(gru_t *, const gru_attr_t *);
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0; void (*ComputeHtPart1)(gru_t *, const gru_attr_t *);
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0; void (*ComputeHtPart2)(gru_t *, const gru_attr_t *);
}; };
template <typename T> template <typename T>
......
...@@ -25,10 +25,6 @@ limitations under the License. */ ...@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -235,154 +231,6 @@ REGISTER_JITKERNEL(vexp, VExpKernel); ...@@ -235,154 +231,6 @@ REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel); REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
REGISTER_JITKERNEL(vtanh, VTanhKernel); REGISTER_JITKERNEL(vtanh, VTanhKernel);
namespace detail {
#ifdef __AVX__
#define ALIGN32 __attribute__((aligned(32)))
#define _PS256_CONST(Name, Val) \
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
#define _PI256_CONST(Name, Val) \
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
_PI256_CONST(0x7f, 0x7f);
_PS256_CONST(one, 1.f);
_PS256_CONST(0p5, 0.5f);
_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
typedef union imm_xmm_union {
__m256i imm;
__m128i xmm[2];
} imm_xmm_union;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \
imm_xmm_union u ALIGN32; \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \
imm_xmm_union u ALIGN32; \
u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \
imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
/* use SSE2 to perform the bitop AVX2 */ \
__m128i x1, x2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1, y); \
x2 = _mm_##fn(x2, y); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
#define AVX2_INTOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
/* use SSE2 to perform the AVX2 integer operation */ \
__m128i x1, x2; \
__m128i y1, y2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1, y1); \
x2 = _mm_##fn(x2, y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
AVX2_BITOP_USING_SSE2(slli_epi32);
AVX2_INTOP_USING_SSE2(add_epi32);
#define AVXEXP_BASE \
__m256 tmp = _mm256_setzero_ps(), fx; \
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one); \
__m256i imm0; \
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); \
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo)); \
/* express exp(x) as exp(g + n*log(2)) */ \
fx = _mm256_mul_ps(x, \
*reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5)); \
tmp = _mm256_floor_ps(fx); \
/* if greater, substract 1 */ \
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
mask = _mm256_and_ps(mask, one); \
fx = _mm256_sub_ps(tmp, mask); \
tmp = _mm256_mul_ps(fx, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
__m256 z = _mm256_mul_ps( \
fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2)); \
x = _mm256_sub_ps(x, tmp); \
x = _mm256_sub_ps(x, z); \
z = _mm256_mul_ps(x, x); \
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5)); \
y = _mm256_mul_ps(y, z); \
y = _mm256_add_ps(y, x); \
y = _mm256_add_ps(y, one); \
/* build 2^n */ \
imm0 = _mm256_cvttps_epi32(fx)
__m256 ExpAVX(__m256 x) {
AVXEXP_BASE;
// two AVX2 instructions using SSE2
imm0 = avx2_mm256_add_epi32(imm0,
*reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = avx2_mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x) {
AVXEXP_BASE;
// two AVX2 instructions
imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = _mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
} // namespace detail
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -38,20 +38,34 @@ typedef struct { ...@@ -38,20 +38,34 @@ typedef struct {
void* checked{nullptr}; void* checked{nullptr};
} lstm_t; } lstm_t;
typedef struct lstm_attr_s { typedef struct {
bool use_peephole; void* gates; // gates: {W_update, W_reset; W_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d; int d;
std::string act_gate, act_cand, act_cell; std::string act_gate, act_cand;
rnn_attr_s() = default;
rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
std::string act_cell;
lstm_attr_s() = default; lstm_attr_s() = default;
lstm_attr_s(int _d, const std::string& _act_gate, lstm_attr_s(int _d, const std::string& _act_gate,
const std::string& _act_cand, const std::string& _act_cell, const std::string& _act_cand, const std::string& _act_cell,
bool _use_peephole = false) bool _use_peephole = false)
: use_peephole(_use_peephole), : rnn_attr_s(_d, _act_gate, _act_cand),
d(_d), use_peephole(_use_peephole),
act_gate(_act_gate),
act_cand(_act_cand),
act_cell(_act_cell) {} act_cell(_act_cell) {}
} lstm_attr_t; };
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -185,6 +185,46 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { ...@@ -185,6 +185,46 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
VMul(gates + d2, gates + d3, ht, d); VMul(gates + d2, gates + d3, ht, d);
} }
// compute h1 without h0
template <typename T>
void GRUH1(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
int d = attr->d;
int d2 = d * 2;
act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d);
VMul(gates, gates + d2, ht, d);
}
template <typename T>
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
// W: {W_update, W_reset; W_state}
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc<T>(attr->act_gate);
act_gate(gates, gates, attr->d * 2);
VMul(ht_1, gates + attr->d, ht, attr->d);
}
template <typename T>
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_cand = getActFunc<T>(attr->act_cand);
int d = attr->d;
T* y = gates + d * 2;
act_cand(y, y, d);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
} // namespace refer } // namespace refer
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -23,140 +23,10 @@ limitations under the License. */ ...@@ -23,140 +23,10 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h" #include "paddle/fluid/operators/math/jit_code.h"
#endif #endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace detail {
#ifdef __AVX__
__m256 ExpAVX(__m256 x);
#endif
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x);
#endif
} // namespace detail
namespace jit = platform::jit;
#ifdef __AVX__
typedef enum { kSigmoid, kRelu, kTanh, kIdentity } act_type;
class AVXAct {
public:
virtual ~AVXAct() = default;
virtual __m256 Compute(__m256 x) const = 0;
};
template <act_type type, jit::cpu_isa_t isa>
class AVXActImpl : public AVXAct {
public:
__m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); }
};
#define AVX_SIGMOID(isa, expisa) \
template <> \
__m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
__m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
x = expisa(x); \
x = _mm256_add_ps(ones, x); \
return _mm256_div_ps(ones, x); \
}
#define AVX_TANH(isa, expisa) \
template <> \
__m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
__m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
x = expisa(x); \
x = _mm256_add_ps(ones, x); \
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
return _mm256_sub_ps(x, ones); \
}
#define AVX_RELU(isa) \
template <> \
__m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
return _mm256_max_ps(x, _mm256_setzero_ps()); \
}
#define AVX_IDENTITY(isa) \
template <> \
__m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
return x; \
}
#define FOR_EACH_AVX_ISA(macro_) \
macro_(jit::avx); \
macro_(jit::avx2); \
macro_(jit::avx512f)
FOR_EACH_AVX_ISA(AVX_RELU);
FOR_EACH_AVX_ISA(AVX_IDENTITY);
AVX_SIGMOID(jit::avx, detail::ExpAVX);
AVX_TANH(jit::avx, detail::ExpAVX);
#ifdef __AVX2__
AVX_SIGMOID(jit::avx2, detail::ExpAVX2);
AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
AVX_TANH(jit::avx2, detail::ExpAVX2);
AVX_TANH(jit::avx512f, detail::ExpAVX2);
#endif
#undef FOR_EACH_AVX_ISA
#undef AVX_IDENTITY
#undef AVX_RELU
#undef AVX_TANH
#undef AVX_SIGMOID
#endif
template <typename T>
static std::shared_ptr<const VActKernel<T>> GetActKernel(
const std::string& type, int n) {
if (type == "sigmoid") {
return std::dynamic_pointer_cast<const VActKernel<T>>(
KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
} else if (type == "relu") {
return std::dynamic_pointer_cast<const VActKernel<T>>(
KernelPool::Instance().template Get<VReluKernel<T>>(n));
} else if (type == "tanh") {
return std::dynamic_pointer_cast<const VActKernel<T>>(
KernelPool::Instance().template Get<VTanhKernel<T>>(n));
} else if (type == "identity" || type == "") {
return std::dynamic_pointer_cast<const VActKernel<T>>(
KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
#ifdef __AVX__
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
#endif
/* LSTM JitKernel */ /* LSTM JitKernel */
template <typename T> template <typename T>
...@@ -290,125 +160,73 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM, ...@@ -290,125 +160,73 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM,
JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM, JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM,
JITKERNEL_LSTM_IMPL); JITKERNEL_LSTM_IMPL);
#undef JITKERNEL_LSTM_IMPL
#undef JITKERNEL_FIND_KEY_LSTM
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_DEFINE_NAME_LSTM
/* GRU JitKernel */ /* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T>
class GRUKernelImpl : public GRUKernel<T> { class GRUKernelImpl : public GRUKernel<T> {
public: public:
explicit GRUKernelImpl(const std::string& act_gate, static inline std::string name(const gru_attr_t& attr) {
const std::string& act_state, int d) PADDLE_THROW("DType should be either float or double");
: GRUKernel<T>() {
d_ = d;
d2_ = d * 2;
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_state_d_ = GetActKernel<T>(act_state, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
}
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates, d_);
act_state_d_->Compute(gates + d2_, gates + d2_, d_);
vmul_d_->Compute(gates, gates + d2_, ht, d_);
}
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates, d2_);
vmul_d_->Compute(ht_1, gates + d_, ht, d_);
}
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_;
act_state_d_->Compute(y, y, d_);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
} }
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel<T>() {
this->ComputeH1 = refer::GRUH1<T>;
this->ComputeHtPart1 = refer::GRUHtPart1<T>;
this->ComputeHtPart2 = refer::GRUHtPart2<T>;
} }
private:
int d_, d2_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
}; };
#define INTRI8_FLOAT(isa) \ #define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \
template <> \ template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \ std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \
const std::string& act_gate, const std::string& act_state, int d) \ std::string key(#ker_key "f"); \
: GRUKernel<float>() { \ key += (attr.act_gate + attr.act_cand); \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \ if (useJIT(attr.d)) { \
avx_act_state_ = GetAVXAct<isa>(act_state); \ /* only jit code need record d*/ \
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \ } \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */ \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \ } \
template <> \ template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \ std::string ker_class##Impl<double>::name(const gru_attr_t& attr) { \
float* gates, const float* ht_1, float* ht) const { \ std::string key(#ker_key "d"); \
/* not exactly equal the any implementation */ \ /* jit code do not support double yet*/ \
__m256 r, ht0; \ if (useMKL(attr.d)) { \
r = _mm256_loadu_ps(gates + 8); \ return key + "mkl"; \
ht0 = _mm256_loadu_ps(ht_1); \ } else { \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \ return key + "any"; \
_mm256_storeu_ps(ht, r); \
} \ } \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
} }
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \ #define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \ template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \ std::shared_ptr<const ker_class<ker_dtype>> \
GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \ KernelPool::Get<ker_class<ker_dtype>, const gru_attr_t&>( \
const std::string& act_gate, const std::string& act_state, int d) const gru_attr_t& attr)
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \ #define JITKERNEL_FIND_KEY_GRU(ker_class, ker_dtype) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_state std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \ #define JITKERNEL_GRU_IMPL(ker, dtype) \
p = std::dynamic_pointer_cast<ker<dtype>>( \ p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d)); std::make_shared<ker##Impl<dtype>>(attr));
REGISTER_JITKERNEL_ARGS_DEPRECATED(gru, GRUKernel, JITKERNEL_DECLARE_GRU, REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DEFINE_NAME_GRU,
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL); JITKERNEL_DECLARE_GRU, JITKERNEL_FIND_KEY_GRU,
JITKERNEL_GRU_IMPL);
#undef INTRI8_FLOAT #undef JITKERNEL_GRU_IMPL
#undef JITKERNEL_NEW_GRU_IMPL #undef JITKERNEL_FIND_KEY_GRU
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU #undef JITKERNEL_DECLARE_GRU
#undef JITKERNEL_DEFINE_NAME_GRU
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册