diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc index bfc1b911a76038d6816f45fec42833390fdc6075..e481d1921a7dc4fd6da3fffbc3959eafa7b4b461 100644 --- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc +++ b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc @@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { } \ } -#define INTRIAVX2_FLOAT(block) \ +#define INTRIAVX2_FLOAT(isa, block) \ template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ - int tag_num) \ + CRFDecodeKernelImpl::CRFDecodeKernelImpl(int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \ this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \ } \ template <> \ - void CRFDecodeKernelImpl::Compute( \ + void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ INIT_ALPHA(AVX2_FLOAT_BLOCK) \ @@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { int j_offset = 0; \ for (int j = 0; j <= this->end_; ++j) { \ /* Initialize the variables of maximum score and location.*/ \ - __m512 max_score = _mm512_set1_ps(-std::numeric_limits::max()); \ + __m512 max_score = _mm512_set1_ps(-std::numeric_limits::max()); \ __m512i max_j = _mm512_setzero_si512(); \ /* Calculate the offset of transition_weights.*/ \ int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ @@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { __m512 x_content = \ _mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \ max_score = _mm512_add_ps(max_score, x_content); \ - _mm512_storeu_ps(alpha_value + seq_offset + this->tag_num_ + j_offset, \ + _mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \ max_score); \ _mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \ this->num_ + j_offset), \ @@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16); INTRIAVX_FLOAT(kGT16); #endif #ifdef __AVX2__ -INTRIAVX2_FLOAT(kEQ8); -INTRIAVX2_FLOAT(kGT8LT16); -INTRIAVX2_FLOAT(kEQ16); -INTRIAVX2_FLOAT(kGT16); +INTRIAVX2_FLOAT(jit::avx2, kEQ8); +INTRIAVX2_FLOAT(jit::avx2, kGT8LT16); +INTRIAVX2_FLOAT(jit::avx2, kEQ16); +INTRIAVX2_FLOAT(jit::avx2, kGT16); #endif #ifdef __AVX512F__ -INTRIAVX2_FLOAT(kEQ8); -INTRIAVX2_FLOAT(kGT8LT16); +INTRIAVX2_FLOAT(jit::avx512f, kEQ8); +INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16); INTRIAVX512_FLOAT(kEQ16); INTRIAVX512_FLOAT(kGT16); #endif