提交 b4751a34 编写于 作者: T tensor-tang

fix illegal instruction of rnn2

上级 36588b33
...@@ -27,13 +27,6 @@ limitations under the License. */ ...@@ -27,13 +27,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
#ifdef __AVX__
namespace detail {
__m256 Exp(__m256 a);
} // namespace detail
#endif
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
...@@ -205,7 +198,7 @@ __m256 ExpAVX(__m256 x) { ...@@ -205,7 +198,7 @@ __m256 ExpAVX(__m256 x) {
#ifdef __AVX2__ #ifdef __AVX2__
__m256 ExpAVX2(__m256 x) { __m256 ExpAVX2(__m256 x) {
__m256 tmp = _mm256_setzero_ps(), fx; __m256 tmp = _mm256_setzero_ps(), fx;
__m256 one = *reinterpret_cast<const __m256*> _ps256_one; __m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
__m256i imm0; __m256i imm0;
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
...@@ -335,7 +328,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -335,7 +328,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
template <> \ template <> \
void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \ void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \ const { \
/*use static const??*/ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ /* TODO(TJ): try to use static const*/ \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \ __m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max, expisa); \ INTRI_SIGMOID(tmp, min, max, expisa); \
......
...@@ -25,13 +25,18 @@ limitations under the License. */ ...@@ -25,13 +25,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
#ifdef __AVX__ namespace jitkernel {
namespace detail { namespace detail {
__m256 Exp(__m256 a); #ifdef __AVX__
} // namespace detail __m256 ExpAVX(__m256 x);
#endif #endif
namespace jitkernel { #ifdef __AVX2__
__m256 ExpAVX2(__m256 x);
#endif
} // namespace detail
namespace jit = platform::jit; namespace jit = platform::jit;
#ifdef __AVX__ #ifdef __AVX__
...@@ -43,43 +48,72 @@ class AVXAct { ...@@ -43,43 +48,72 @@ class AVXAct {
virtual __m256 Compute(__m256 x) const = 0; virtual __m256 Compute(__m256 x) const = 0;
}; };
template <act_type type> template <act_type type, jit::cpu_isa_t isa>
class AVXActImpl : public AVXAct { class AVXActImpl : public AVXAct {
public: public:
__m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); } __m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); }
}; };
template <> #define AVX_SIGMOID(isa, expisa) \
__m256 AVXActImpl<kSigmoid>::Compute(__m256 x) const { template <> \
__m256 ones = _mm256_set1_ps(1.0f); __m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); __m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
x = detail::Exp(x); x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
x = _mm256_add_ps(ones, x); x = expisa(x); \
return _mm256_div_ps(ones, x); x = _mm256_add_ps(ones, x); \
} return _mm256_div_ps(ones, x); \
}
template <> #define AVX_TANH(isa, expisa) \
__m256 AVXActImpl<kTanh>::Compute(__m256 x) const { template <> \
__m256 ones = _mm256_set1_ps(1.0f); __m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); __m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
x = detail::Exp(x); x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
x = _mm256_add_ps(ones, x); x = expisa(x); \
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); x = _mm256_add_ps(ones, x); \
return _mm256_sub_ps(x, ones); x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
} return _mm256_sub_ps(x, ones); \
}
template <> #define AVX_RELU(isa) \
__m256 AVXActImpl<kRelu>::Compute(__m256 x) const { template <> \
return _mm256_max_ps(x, _mm256_setzero_ps()); __m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
} return _mm256_max_ps(x, _mm256_setzero_ps()); \
}
#define AVX_IDENTITY(isa) \
template <> \
__m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
return x; \
}
#define FOR_EACH_AVX_ISA(macro_) \
macro_(jit::avx); \
macro_(jit::avx2); \
macro_(jit::avx512f)
FOR_EACH_AVX_ISA(AVX_RELU);
FOR_EACH_AVX_ISA(AVX_IDENTITY);
AVX_SIGMOID(jit::avx, detail::ExpAVX);
AVX_TANH(jit::avx, detail::ExpAVX);
#ifdef __AVX2__
AVX_SIGMOID(jit::avx2, detail::ExpAVX2);
AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
AVX_TANH(jit::avx2, detail::ExpAVX2);
AVX_TANH(jit::avx512f, detail::ExpAVX2);
#endif
#undef FOR_EACH_AVX_ISA
#undef AVX_IDENTITY
#undef AVX_RELU
#undef AVX_TANH
#undef AVX_SIGMOID
template <>
__m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
return x;
}
#endif #endif
template <typename T> template <typename T>
...@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cell_d_ = GetActKernel<T>(act_cell, d); act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d); vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d); vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
#ifdef __AVX__
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity>());
}
PADDLE_THROW("Not support type: %s", type);
};
avx_act_gate_ = GetAVXAct(act_gate);
avx_act_cand_ = GetAVXAct(act_cand);
avx_act_cell_ = GetAVXAct(act_cell);
#endif
} }
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
...@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \ LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
float* gates, const float* ct_1, float* ct, float* ht, \ const std::string& act_gate, const std::string& act_cand, \
const float* wp_data, float* checked) const { \ const std::string& act_cell, int d) \
/* gates: W_ch, W_ih, W_fh, W_oh */ \ : LSTMKernel<float>() { \
__m256 c, i, f, o; \ auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
c = _mm256_loadu_ps(gates); \ if (type == "sigmoid") { \
i = _mm256_loadu_ps(gates + 8); \ return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
f = _mm256_loadu_ps(gates + 16); \ } else if (type == "relu") { \
o = _mm256_loadu_ps(gates + 24); \ return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \ } else if (type == "tanh") { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
i = _mm256_loadu_ps(ct_1); \ } else if (type == "identity" || type == "") { \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
f = _mm256_add_ps(c, f); \ } \
_mm256_storeu_ps(ct, f); \ PADDLE_THROW("Not support type: %s", type); \
/* H_t = act_cell(C_t) * ogated */ \ }; \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ avx_act_gate_ = GetAVXAct(act_gate); \
_mm256_storeu_ps(ht, o); \ avx_act_cand_ = GetAVXAct(act_cand); \
avx_act_cell_ = GetAVXAct(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */ \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} }
// TODO(TJ): optimize keq16 // TODO(TJ): optimize keq16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册