提交 64d5b438 编写于 作者: T tensor-tang

fix crf decode avx512

上级 21487d78
......@@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
} \
}
#define INTRIAVX2_FLOAT(block) \
#define INTRIAVX2_FLOAT(isa, block) \
template <> \
CRFDecodeKernelImpl<float, jit::avx2, block>::CRFDecodeKernelImpl( \
int tag_num) \
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx2, block>::Compute( \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(AVX2_FLOAT_BLOCK) \
......@@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/ \
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max()); \
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); \
__m512i max_j = _mm512_setzero_si512(); \
/* Calculate the offset of transition_weights.*/ \
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
......@@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
__m512 x_content = \
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm512_add_ps(max_score, x_content); \
_mm512_storeu_ps(alpha_value + seq_offset + this->tag_num_ + j_offset, \
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \
max_score); \
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
this->num_ + j_offset), \
......@@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16);
INTRIAVX_FLOAT(kGT16);
#endif
#ifdef __AVX2__
INTRIAVX2_FLOAT(kEQ8);
INTRIAVX2_FLOAT(kGT8LT16);
INTRIAVX2_FLOAT(kEQ16);
INTRIAVX2_FLOAT(kGT16);
INTRIAVX2_FLOAT(jit::avx2, kEQ8);
INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
INTRIAVX2_FLOAT(jit::avx2, kEQ16);
INTRIAVX2_FLOAT(jit::avx2, kGT16);
#endif
#ifdef __AVX512F__
INTRIAVX2_FLOAT(kEQ8);
INTRIAVX2_FLOAT(kGT8LT16);
INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
INTRIAVX512_FLOAT(kEQ16);
INTRIAVX512_FLOAT(kGT16);
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册