提交 159be8cc 编写于 作者: T tensor-tang

optimize fusion gru kernel at size 8

上级 83dc6898
...@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel( ...@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return nullptr; return nullptr;
} }
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
/* LSTM JitKernel */ /* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
class LSTMKernelImpl : public LSTMKernel<T> { class LSTMKernelImpl : public LSTMKernel<T> {
...@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \ LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \ const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \ const std::string& act_cell, int d) \
: LSTMKernel<float>() { \ : LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \ avx_act_gate_ = GetAVXAct<isa>(act_gate); \
if (type == "sigmoid") { \ avx_act_cand_ = GetAVXAct<isa>(act_cand); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \ avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} else if (type == "relu") { \ } \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \ template <> \
} else if (type == "tanh") { \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \ float* gates, const float* ct_1, float* ct, float* ht, \
} else if (type == "identity" || type == "") { \ const float* wp_data, float* checked) const { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \ /* gates: W_ch, W_ih, W_fh, W_oh */ \
} \ __m256 c, i, f, o; \
PADDLE_THROW("Not support type: %s", type); \ c = _mm256_loadu_ps(gates); \
}; \ i = _mm256_loadu_ps(gates + 8); \
avx_act_gate_ = GetAVXAct(act_gate); \ f = _mm256_loadu_ps(gates + 16); \
avx_act_cand_ = GetAVXAct(act_cand); \ o = _mm256_loadu_ps(gates + 24); \
avx_act_cell_ = GetAVXAct(act_cell); \ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
} \ c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
template <> \ i = _mm256_loadu_ps(ct_1); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \ f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
float* gates, const float* ct_1, float* ct, float* ht, \ f = _mm256_add_ps(c, f); \
const float* wp_data, float* checked) const { \ _mm256_storeu_ps(ct, f); \
/* gates: W_ch, W_ih, W_fh, W_oh */ \ /* H_t = act_cell(C_t) * ogated */ \
__m256 c, i, f, o; \ o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
c = _mm256_loadu_ps(gates); \ _mm256_storeu_ps(ht, o); \
i = _mm256_loadu_ps(gates + 8); \ } \
f = _mm256_loadu_ps(gates + 16); \ template <> \
o = _mm256_loadu_ps(gates + 24); \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \ float* gates, float* ct, float* ht, const float* wp_data) const { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ __m256 c, i, o; \
i = _mm256_loadu_ps(ct_1); \ c = _mm256_loadu_ps(gates); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ i = _mm256_loadu_ps(gates + 8); \
f = _mm256_add_ps(c, f); \ o = _mm256_loadu_ps(gates + 24); \
_mm256_storeu_ps(ct, f); \ /* C_t = igated * cgated*/ \
/* H_t = act_cell(C_t) * ogated */ \ c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ _mm256_storeu_ps(ct, c); \
_mm256_storeu_ps(ht, o); \ /* H_t = act_cell(C_t) * ogated */ \
} \ o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
template <> \ _mm256_storeu_ps(ht, o); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} }
// TODO(TJ): optimize keq16 // TODO(TJ): optimize keq16
...@@ -375,6 +378,7 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -375,6 +378,7 @@ class GRUKernelImpl : public GRUKernel<T> {
act_state_d_->Compute(gates + d2_, gates + d2_); act_state_d_->Compute(gates + d2_, gates + d2_);
vmul_d_->Compute(gates, gates + d2_, ht); vmul_d_->Compute(gates, gates + d2_, ht);
} }
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates); act_gate_d2_->Compute(gates, gates);
...@@ -394,8 +398,65 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -394,8 +398,65 @@ class GRUKernelImpl : public GRUKernel<T> {
int d_, d2_; int d_, d2_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_; std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_; std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
}; };
#define INTRI8_FLOAT(isa) \
template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
const std::string& act_gate, const std::string& act_state, int d) \
: GRUKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_state_ = GetAVXAct<isa>(act_state); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */ \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 r, ht0; \
r = _mm256_loadu_ps(gates + 8); \
ht0 = _mm256_loadu_ps(ht_1); \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
_mm256_storeu_ps(ht, r); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
}
// Expand INTRI8_FLOAT once per AVX-family instruction set whose compiler
// predefined macro (__AVX__, __AVX2__, __AVX512F__) is defined for this build.
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \ #define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \ template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \ std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
...@@ -412,6 +473,7 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -412,6 +473,7 @@ class GRUKernelImpl : public GRUKernel<T> {
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU, REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL); JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL #undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU #undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU #undef JITKERNEL_DECLARE_GRU
......
...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp): ...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
self.D = 8 self.D = 8
class TestFusionGRUOpMD3(TestFusionGRUOp):
    """Variant of TestFusionGRUOp that runs with M = 17 and D = 15."""

    def set_confs(self):
        # Override only the two size settings; everything else is inherited.
        self.M, self.D = 17, 15
class TestFusionGRUOpBS1(TestFusionGRUOp): class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self): def set_confs(self):
self.lod = [[3]] self.lod = [[3]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册