提交 64d5b438 编写于 作者: T tensor-tang

fix crf decode avx512

上级 21487d78
...@@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
} \ } \
} }
#define INTRIAVX2_FLOAT(block) \ #define INTRIAVX2_FLOAT(isa, block) \
template <> \ template <> \
CRFDecodeKernelImpl<float, jit::avx2, block>::CRFDecodeKernelImpl( \ CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \ this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \ this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx2, block>::Compute( \ void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX2_FLOAT_BLOCK) \ INIT_ALPHA(AVX2_FLOAT_BLOCK) \
...@@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int j_offset = 0; \ int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \ for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/ \ /* Initialize the variables of maximum score and location.*/ \
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max()); \ __m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); \
__m512i max_j = _mm512_setzero_si512(); \ __m512i max_j = _mm512_setzero_si512(); \
/* Calculate the offset of transition_weights.*/ \ /* Calculate the offset of transition_weights.*/ \
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
...@@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
__m512 x_content = \ __m512 x_content = \
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \ _mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm512_add_ps(max_score, x_content); \ max_score = _mm512_add_ps(max_score, x_content); \
_mm512_storeu_ps(alpha_value + seq_offset + this->tag_num_ + j_offset, \ _mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \
max_score); \ max_score); \
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \ _mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
this->num_ + j_offset), \ this->num_ + j_offset), \
...@@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16); ...@@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16);
INTRIAVX_FLOAT(kGT16); INTRIAVX_FLOAT(kGT16);
#endif #endif
#ifdef __AVX2__ #ifdef __AVX2__
INTRIAVX2_FLOAT(kEQ8); INTRIAVX2_FLOAT(jit::avx2, kEQ8);
INTRIAVX2_FLOAT(kGT8LT16); INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
INTRIAVX2_FLOAT(kEQ16); INTRIAVX2_FLOAT(jit::avx2, kEQ16);
INTRIAVX2_FLOAT(kGT16); INTRIAVX2_FLOAT(jit::avx2, kGT16);
#endif #endif
#ifdef __AVX512F__ #ifdef __AVX512F__
INTRIAVX2_FLOAT(kEQ8); INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
INTRIAVX2_FLOAT(kGT8LT16); INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
INTRIAVX512_FLOAT(kEQ16); INTRIAVX512_FLOAT(kEQ16);
INTRIAVX512_FLOAT(kGT16); INTRIAVX512_FLOAT(kGT16);
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册