Commit 2a009691 authored by T tensor-tang

optimize lstm jitkernel keq8

test=develop
Parent f2adaf1c
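This commit splits the LSTM kernels into their own cc_library (jit_kernel_lstm) and adds an AVX fast path for the k == 8 case: when the block width is exactly eight floats, ComputeCtHt is specialized to run the whole cell update in __m256 registers. For orientation, here is a minimal scalar sketch of the cell step that specialization implements. The gate layout (cand, igate, fgate, ogate, each d wide) follows the /* gates: W_ch, W_ih, W_fh, W_oh */ comment in the macro below; the fixed tanh/sigmoid choice is my simplifying assumption, since the real kernel takes the activation names as parameters. ComputeCtHtScalar is a hypothetical helper, not part of the commit.

#include <cmath>

// Scalar reference for one LSTM cell step over d lanes (d == 8 in the kEQ8
// kernel). gates holds pre-activations [cand, igate, fgate, ogate], each d
// wide. Assumes sigmoid gate activations and tanh cand/cell activations.
void ComputeCtHtScalar(const float* gates, const float* ct_1, float* ct,
                       float* ht, int d) {
  const float* cand = gates;
  const float* ig = gates + d;
  const float* fg = gates + 2 * d;
  const float* og = gates + 3 * d;
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  for (int k = 0; k < d; ++k) {
    // C_t = act_cand(cand) * act_gate(igate) + C_{t-1} * act_gate(fgate)
    ct[k] = std::tanh(cand[k]) * sigmoid(ig[k]) + ct_1[k] * sigmoid(fg[k]);
    // H_t = act_cell(C_t) * act_gate(ogate)
    ht[k] = std::tanh(ct[k]) * sigmoid(og[k]);
  }
}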
@@ -77,5 +77,6 @@ endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 cc_library(jit_kernel_exp SRCS jit_kernel_exp.cc DEPS cpu_info cblas activation_functions)
-cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_lstm.cc DEPS cpu_info cblas jit_kernel_exp)
+cc_library(jit_kernel_lstm SRCS jit_kernel_lstm.cc DEPS cpu_info cblas activation_functions)
+cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc DEPS cpu_info cblas jit_kernel_exp jit_kernel_lstm)
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/math/jit_kernel_macro.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/macros.h"
 #ifdef __AVX__
 #include <immintrin.h>
@@ -24,10 +25,63 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace math {
+#ifdef __AVX__
+namespace detail {
+__m256 Exp(__m256 a);
+}  // namespace detail
+#endif
+
 namespace jitkernel {
 namespace jit = platform::jit;
+
+#ifdef __AVX__
+typedef enum { kSigmoid, kRelu, kTanh, kIdentity } act_type;
+
+class AVXAct {
+ public:
+  virtual ~AVXAct() = default;
+  virtual __m256 Compute(__m256 x) const = 0;
+};
+
+template <act_type type>
+class AVXActImpl : public AVXAct {
+ public:
+  __m256 Compute(__m256 x) const override { PADDLE_THROW("Unknown type!"); }
+};
+
+template <>
+__m256 AVXActImpl<kSigmoid>::Compute(__m256 x) const {
+  __m256 ones = _mm256_set1_ps(1.0f);
+  x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN));
+  x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX));
+  x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x);
+  x = detail::Exp(x);
+  x = _mm256_add_ps(ones, x);
+  return _mm256_div_ps(ones, x);
+}
+
+template <>
+__m256 AVXActImpl<kTanh>::Compute(__m256 x) const {
+  __m256 ones = _mm256_set1_ps(1.0f);
+  x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x);
+  x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT));
+  x = detail::Exp(x);
+  x = _mm256_add_ps(ones, x);
+  x = _mm256_div_ps(_mm256_set1_ps(2.0f), x);
+  return _mm256_sub_ps(x, ones);
+}
+
+template <>
+__m256 AVXActImpl<kRelu>::Compute(__m256 x) const {
+  return _mm256_max_ps(x, _mm256_setzero_ps());
+}
+
+template <>
+__m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
+  return x;
+}
+#endif
+
 /* LSTM JitKernel */
 template <typename T, jit::cpu_isa_t isa, jit_block>
 class LSTMKernelImpl : public LSTMKernel<T> {
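The AVXAct hierarchy added above trades a per-element string lookup for one virtual call per eight-float vector: the activation name is resolved to a template specialization once, when the kernel is constructed (see the GetAVXAct lambda in the next hunk). Below is a stripped-down, self-contained sketch of the same enum-specialization pattern, with plain floats standing in for __m256; the names here are illustrative, not Paddle's.

#include <cstdio>
#include <memory>
#include <stdexcept>

typedef enum { kRelu, kIdentity } act_type;

struct Act {
  virtual ~Act() = default;
  virtual float Compute(float x) const = 0;  // __m256 in the real kernel
};

template <act_type type>
struct ActImpl : public Act {
  // Primary template mirrors the PADDLE_THROW fallback in the commit.
  float Compute(float) const override {
    throw std::runtime_error("unknown act_type");
  }
};

// Explicit specializations provide the real per-activation bodies.
template <>
float ActImpl<kRelu>::Compute(float x) const {
  return x > 0.f ? x : 0.f;
}
template <>
float ActImpl<kIdentity>::Compute(float x) const {
  return x;
}

int main() {
  // Resolve the activation once; hot loops then pay only the virtual call.
  std::unique_ptr<const Act> act(new ActImpl<kRelu>());
  std::printf("%f\n", act->Compute(-3.f));  // prints 0.000000
  return 0;
}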
@@ -61,6 +115,23 @@ class LSTMKernelImpl : public LSTMKernel<T> {
     act_cell_d_ = GetActKernel(act_cell, d);
     vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
     vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
+#ifdef __AVX__
+    auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
+      if (type == "sigmoid") {
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid>());
+      } else if (type == "relu") {
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu>());
+      } else if (type == "tanh") {
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh>());
+      } else if (type == "identity" || type == "") {
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity>());
+      }
+      PADDLE_THROW("Unsupported type: %s", type);
+    };
+    avx_act_gate_ = GetAVXAct(act_gate);
+    avx_act_cand_ = GetAVXAct(act_cand);
+    avx_act_cell_ = GetAVXAct(act_cell);
+#endif
   }
   void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override {
@@ -83,8 +154,44 @@ class LSTMKernelImpl : public LSTMKernel<T> {
   std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
   std::shared_ptr<const VMulKernel<T>> vmul_d_;
   std::shared_ptr<const VAddKernel<T>> vadd_d_;
+#ifdef __AVX__
+  std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_cand_, avx_act_cell_;
+#endif
 };
+
+#define INTRI8_FLOAT(isa)                                                    \
+  template <>                                                                \
+  void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt(                        \
+      float* gates, const float* ct_1, float* ct, float* ht) const {         \
+    /* gates: W_ch, W_ih, W_fh, W_oh */                                      \
+    __m256 c, i, f, o;                                                       \
+    c = _mm256_loadu_ps(gates);                                              \
+    i = _mm256_loadu_ps(gates + 8);                                          \
+    f = _mm256_loadu_ps(gates + 16);                                         \
+    o = _mm256_loadu_ps(gates + 24);                                         \
+    /* C_t = C_t-1 * fgated + cand_gated * igated */                         \
+    c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
+    i = _mm256_loadu_ps(ct_1);                                               \
+    f = _mm256_mul_ps(i, avx_act_gate_->Compute(f));                         \
+    f = _mm256_add_ps(c, f);                                                 \
+    _mm256_storeu_ps(ct, f);                                                 \
+    /* H_t = act_cell(C_t) * ogated */                                       \
+    o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
+    _mm256_storeu_ps(ht, o);                                                 \
+  }
+
+// TODO(TJ): optimize keq16
+#ifdef __AVX__
+INTRI8_FLOAT(jit::avx);
+#endif
+#ifdef __AVX2__
+INTRI8_FLOAT(jit::avx2);
+#endif
+#ifdef __AVX512F__
+INTRI8_FLOAT(jit::avx512f);
+#endif
 
 #define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype)     \
   template <>                                            \
   std::shared_ptr<const ker_class<ker_dtype>>            \
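In the kEQ8 specialization above, each gate block is exactly one __m256, so the whole cell step stays in registers: four unaligned loads, the gated multiplies and one add, and two stores, with i and f reused as scratch once their gate values are consumed. The tanh path relies on the identity tanh(x) = 2/(1 + e^(-2x)) - 1, capping -2x at EXP_MAX_INPUT before the vectorized exp to avoid overflow, while the sigmoid path first clamps its input to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX]. A tiny scalar self-check of the tanh reformulation (my own snippet, not part of the commit):

#include <cmath>
#include <cstdio>

int main() {
  // Compare std::tanh with the 2/(1+e^(-2x)) - 1 form used by
  // AVXActImpl<kTanh> in the diff above.
  for (float x = -4.f; x <= 4.f; x += 0.5f) {
    float reformulated = 2.f / (1.f + std::exp(-2.f * x)) - 1.f;
    std::printf("x=% .2f  std::tanh=% .6f  reformulated=% .6f\n", x,
                std::tanh(x), reformulated);
  }
  return 0;
}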
@@ -104,6 +211,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
 REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
                         JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
 
+#undef INTRI8_FLOAT
 #undef JITKERNEL_DECLARE_LSTM
 #undef JITKERNEL_KEY_LSTM
 #undef JITKERNEL_NEW_LSTM_IMPL
......