diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 78ef6f207eadea6799864fe22889103b468d1780..067f2f7316c30b1ec57d2768ba271b1e9d41ab31 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -300,6 +300,7 @@ op_library(flatten_op DEPS reshape_op) op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) op_library(fake_quantize_op DEPS memory) +op_library(crf_decoding_op DEPS jit_kernel) op_library(fusion_lstm_op DEPS jit_kernel) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/crf_decoding_op.h b/paddle/fluid/operators/crf_decoding_op.h index 8181897c3d3844bda5574e85a08b2af038fcd664..e9d2e84a434d7084c526a6e75363a65577197262 100644 --- a/paddle/fluid/operators/crf_decoding_op.h +++ b/paddle/fluid/operators/crf_decoding_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { @@ -69,9 +70,6 @@ class CRFDecodingOpKernel : public framework::OpKernel { auto emission_dims = emission_weights.dims(); const size_t seq_len = emission_dims[0]; const size_t tag_num = emission_dims[1]; - - const size_t state_trans_base_idx = 2; - const T* x = emission_weights.data(); const T* w = transition_weights.data(); int64_t* path = decoded_path->data(); @@ -84,221 +82,10 @@ class CRFDecodingOpKernel : public framework::OpKernel { Tensor track; int* track_value = track.mutable_data(emission_dims, platform::CPUPlace()); - -#ifdef __AVX__ -// It use the AVX or AVX512 instruction to deal the data as the vector of 8 or -// 16 elements per iteration. Then it can implement the parallel processing. -// Only optimize for float type. -#ifdef __AVX512F__ - size_t step_size = 16; -#else - size_t step_size = 8; -#endif - if (std::is_same::value && (tag_num >= step_size)) { - size_t steps = tag_num / step_size; - size_t remain = tag_num % step_size; - int last_offset = static_cast(remain) - static_cast(step_size); - - // Setup the alpha initial value. - size_t i_offset = 0; - for (size_t i = 0; i <= steps; ++i) { -#ifdef __AVX512F__ - // Declare the variable for the content of weights, input and alpha - // values. - __m512 w_content, x_content, alpha_content; - - // Load the relevant data into the variables from un-aligned address. - w_content = _mm512_loadu_ps((const float*)(w + i_offset)); - x_content = _mm512_loadu_ps((const float*)(x + i_offset)); - alpha_content = _mm512_add_ps(w_content, x_content); - - // Save the alpha value. - _mm512_storeu_ps(reinterpret_cast(alpha_value + i_offset), - alpha_content); -#else - // Declare the variable for the content of weights, input and alpha - // values. - __m256 w_content, x_content, alpha_content; - - // Load the relevant data into the variables from un-aligned address. - w_content = _mm256_loadu_ps((const float*)(w + i_offset)); - x_content = _mm256_loadu_ps((const float*)(x + i_offset)); - alpha_content = _mm256_add_ps(w_content, x_content); - - // Save the alpha value. - _mm256_storeu_ps(reinterpret_cast(alpha_value + i_offset), - alpha_content); -#endif - i_offset += step_size; - if (i == steps - 1) { - if (remain > 0) { - i_offset += last_offset; - } else { - break; - } - } - } - - // Use the column-major strategy to get the location of maximum score. - size_t seq_offset = 0; - for (size_t k = 1; k < seq_len; ++k) { - size_t j_offset = 0; - for (size_t j = 0; j <= steps; ++j) { -#ifdef __AVX512F__ - // Initialize the variables of maximum score and location. - __m512 max_score = _mm512_set1_ps(-std::numeric_limits::max()); - __m512i max_j = _mm512_setzero_si512(); -#else - // Initialize the variables of maximum score and location. - __m256 max_score = _mm256_set1_ps(-std::numeric_limits::max()); - __m256i max_j = _mm256_set1_epi32(0); -#endif - // Calculate the offset of transition_weights. - size_t trans_offset = state_trans_base_idx * tag_num + j_offset; - for (size_t i = 0; i < tag_num; ++i) { -#ifdef __AVX512F__ - // Initalize the content of alpha variable with related offset. - __m512 alpha_content = - _mm512_set1_ps(*(const float*)(alpha_value + seq_offset + i)); - // Obtain the content of weights from un-aligned address. - __m512 w_content = - _mm512_loadu_ps((const float*)(w + trans_offset)); - - __m512 score_v = _mm512_add_ps(alpha_content, w_content); - - __mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); - - // According to the mask value, it update the index of the max_score - // location. - max_j = _mm512_mask_set1_epi32(max_j, mask, i); - - // Update the max_score value. - max_score = _mm512_max_ps(max_score, score_v); -#else - // Initalize the content of alpha variable with related offset. - __m256 alpha_content = _mm256_broadcast_ss( - (const float*)(alpha_value + seq_offset + i)); - // Obtain the content of weights from un-aligned address. - __m256 w_content = - _mm256_loadu_ps((const float*)(w + trans_offset)); - __m256 score_v = _mm256_add_ps(alpha_content, w_content); - - __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); - -#ifdef __AVX2__ - // According to the mask value, it update the index of the max_score - // location. - max_j = _mm256_or_si256( - _mm256_andnot_si256((__m256i)mask, max_j), - _mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); -#else - __m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); - __m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); - __m128i lo_mask = _mm256_extractf128_si256((__m256i)mask, 0); - __m128i hi_mask = _mm256_extractf128_si256((__m256i)mask, 1); - - lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); - hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); - lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); - hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); - - lo_max_j = _mm_or_si128(lo_mask, lo_max_j); - hi_max_j = _mm_or_si128(hi_mask, hi_max_j); - - // According to the mask value, it update the index of the max_score - // location. - max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); - max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); -#endif - - // Update the max_score value. - max_score = _mm256_max_ps(max_score, score_v); -#endif - trans_offset += tag_num; - } - -#ifdef __AVX512F__ - // Update the alpha and track values. - __m512 x_content = _mm512_loadu_ps( - (const float*)(x + seq_offset + tag_num + j_offset)); - max_score = _mm512_add_ps(max_score, x_content); - _mm512_storeu_ps(reinterpret_cast(alpha_value + seq_offset + - tag_num + j_offset), - max_score); - _mm512_storeu_si512( - reinterpret_cast<__m512i*>(track_value + seq_offset + tag_num + - j_offset), - max_j); -#else - // Update the alpha and track values. - __m256 x_content = _mm256_loadu_ps( - (const float*)(x + seq_offset + tag_num + j_offset)); - max_score = _mm256_add_ps(max_score, x_content); - _mm256_storeu_ps(reinterpret_cast(alpha_value + seq_offset + - tag_num + j_offset), - max_score); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(track_value + seq_offset + tag_num + - j_offset), - max_j); -#endif - - // Calculate the offset of next step - j_offset += step_size; - if (j == steps - 1) { - if (remain > 0) { - j_offset += last_offset; - } else { - break; - } - } - } - - seq_offset += tag_num; - } - } else { - for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i]; - - for (size_t k = 1; k < seq_len; ++k) { - for (size_t i = 0; i < tag_num; ++i) { - T max_score = -std::numeric_limits::max(); - int max_j = 0; - for (size_t j = 0; j < tag_num; ++j) { - T score = alpha_value[(k - 1) * tag_num + j] + - w[(j + state_trans_base_idx) * tag_num + i]; - if (score > max_score) { - max_score = score; - max_j = j; - } - } - - alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i]; - track_value[k * tag_num + i] = max_j; - } - } - } -#else - for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i]; - - for (size_t k = 1; k < seq_len; ++k) { - for (size_t i = 0; i < tag_num; ++i) { - T max_score = -std::numeric_limits::max(); - int max_j = 0; - for (size_t j = 0; j < tag_num; ++j) { - T score = alpha_value[(k - 1) * tag_num + j] + - w[(j + state_trans_base_idx) * tag_num + i]; - if (score > max_score) { - max_score = score; - max_j = j; - } - } - - alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i]; - track_value[k * tag_num + i] = max_j; - } - } - -#endif + const auto& ker = math::jitkernel::KernelPool::Instance() + .template Get>( + static_cast(tag_num)); + ker->Compute(static_cast(seq_len), x, w, alpha_value, track_value); T max_score = -std::numeric_limits::max(); int max_i = 0; for (size_t i = 0; i < tag_num; ++i) { diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 55e2ea760158cda631ec07e2c7d318ec1cf79b77..17b675fba8067851f6149edafcc9096690a3fd34 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -76,6 +76,6 @@ endif() cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_library(jit_kernel - SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc + SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc DEPS cpu_info cblas) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 9088d0c7a6307c3fbd9707c719ec9e6f6c85fbdb..48e180b1fd43b06cc13f7a4b00c73aff2eb940ac 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -151,6 +151,13 @@ class GRUKernel : public Kernel { virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0; }; +template +class CRFDecodeKernel : public Kernel { + public: + virtual void Compute(const int seq_len, const T *x, const T *w, T *alpha, + int *track) const = 0; +}; + } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc new file mode 100644 index 0000000000000000000000000000000000000000..bfc1b911a76038d6816f45fec42833390fdc6075 --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc @@ -0,0 +1,297 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/jit_kernel.h" +#include +#include +#include "paddle/fluid/operators/math/jit_kernel_macro.h" +#ifdef __AVX__ +#include +#endif + +namespace paddle { +namespace operators { +namespace math { +namespace jitkernel { + +namespace jit = platform::jit; + +/* CRF Decode JitKernel */ +template +class CRFDecodeKernelImpl : public CRFDecodeKernel { + public: + explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel() { + this->num_ = tag_num; + } + void Compute(const int seq_len, const T* x, const T* w, T* alpha, + int* track) const override { + constexpr int state_trans_base_idx = 2; + for (int i = 0; i < this->num_; ++i) { + alpha[i] = w[i] + x[i]; + } + for (int k = 1; k < seq_len; ++k) { + for (int i = 0; i < this->num_; ++i) { + T max_score = -std::numeric_limits::max(); + int max_j = 0; + for (int j = 0; j < this->num_; ++j) { + T score = alpha[(k - 1) * this->num_ + j] + + w[(j + state_trans_base_idx) * this->num_ + i]; + if (score > max_score) { + max_score = score; + max_j = j; + } + } + alpha[k * this->num_ + i] = max_score + x[k * this->num_ + i]; + track[k * this->num_ + i] = max_j; + } + } + } +}; + +#define INIT_ALPHA(step_size) \ + /* Setup the alpha initial value.*/ \ + int i_offset = 0; \ + int last_offset = this->rest_ - step_size; \ + for (int i = 0; i <= this->end_; ++i) { \ + /* weights, input and alpha values. */ \ + __m256 w_content, x_content, alpha_content; \ + /* Load the relevant data into the variables from un-aligned address.*/ \ + w_content = _mm256_loadu_ps(w + i_offset); \ + x_content = _mm256_loadu_ps(x + i_offset); \ + alpha_content = _mm256_add_ps(w_content, x_content); \ + _mm256_storeu_ps(alpha + i_offset, alpha_content); \ + i_offset += step_size; \ + if (i == this->end_ - 1) { \ + if (this->rest_ > 0) { \ + i_offset += last_offset; \ + } else { \ + break; \ + } \ + } \ + } + +#define UPDATE_ALPHA(step_size) \ + /* Update the alpha and track values. */ \ + __m256 x_content = _mm256_loadu_ps(x + seq_offset + this->num_ + j_offset); \ + max_score = _mm256_add_ps(max_score, x_content); \ + _mm256_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score); \ + _mm256_storeu_si256( \ + reinterpret_cast<__m256i*>(track + seq_offset + this->num_ + j_offset), \ + max_j); \ + /* Calculate the offset of next step*/ \ + j_offset += step_size; \ + if (j == this->end_ - 1) { \ + if (this->rest_ > 0) { \ + j_offset += last_offset; \ + } else { \ + break; \ + } \ + } + +#define INTRIAVX_FLOAT(block) \ + template <> \ + CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ + int tag_num) \ + : CRFDecodeKernel() { \ + this->num_ = tag_num; \ + this->end_ = this->num_ / AVX_FLOAT_BLOCK; \ + this->rest_ = this->num_ % AVX_FLOAT_BLOCK; \ + } \ + template <> \ + void CRFDecodeKernelImpl::Compute( \ + const int seq_len, const float* x, const float* w, float* alpha, \ + int* track) const { \ + INIT_ALPHA(AVX_FLOAT_BLOCK) \ + /* Use the column-major strategy to get the location of maximum score.*/ \ + int seq_offset = 0; \ + constexpr int state_trans_base_idx = 2; \ + for (int k = 1; k < seq_len; ++k) { \ + int j_offset = 0; \ + for (int j = 0; j <= this->end_; ++j) { \ + /* Initialize the variables of maximum score and location.*/ \ + __m256 max_score = _mm256_set1_ps(-std::numeric_limits::max()); \ + __m256i max_j = _mm256_set1_epi32(0); \ + /* Calculate the offset of transition_weights.*/ \ + int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ + for (int i = 0; i < this->num_; ++i) { \ + /* Initalize the content of alpha variable with related offset.*/ \ + __m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \ + /* Obtain the content of weights from un-aligned address.*/ \ + __m256 w_content = _mm256_loadu_ps(w + trans_offset); \ + __m256 score_v = _mm256_add_ps(alpha_content, w_content); \ + __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \ + /* According to the mask value, update the index of the max_score.*/ \ + /* AVX instructions.*/ \ + __m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); \ + __m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); \ + __m128i lo_mask = _mm256_extractf128_si256((__m256i)mask, 0); \ + __m128i hi_mask = _mm256_extractf128_si256((__m256i)mask, 1); \ + lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); \ + hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); \ + lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); \ + hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); \ + lo_max_j = _mm_or_si128(lo_mask, lo_max_j); \ + hi_max_j = _mm_or_si128(hi_mask, hi_max_j); \ + max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); \ + max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); \ + /* AVX done*/ \ + /* Update the max_score value.*/ \ + max_score = _mm256_max_ps(max_score, score_v); \ + trans_offset += this->num_; \ + } \ + UPDATE_ALPHA(AVX_FLOAT_BLOCK) \ + } \ + seq_offset += this->num_; \ + } \ + } + +#define INTRIAVX2_FLOAT(block) \ + template <> \ + CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ + int tag_num) \ + : CRFDecodeKernel() { \ + this->num_ = tag_num; \ + this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \ + this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \ + } \ + template <> \ + void CRFDecodeKernelImpl::Compute( \ + const int seq_len, const float* x, const float* w, float* alpha, \ + int* track) const { \ + INIT_ALPHA(AVX2_FLOAT_BLOCK) \ + /* Use the column-major strategy to get the location of maximum score.*/ \ + int seq_offset = 0; \ + constexpr int state_trans_base_idx = 2; \ + for (int k = 1; k < seq_len; ++k) { \ + int j_offset = 0; \ + for (int j = 0; j <= this->end_; ++j) { \ + /* Initialize the variables of maximum score and location.*/ \ + __m256 max_score = _mm256_set1_ps(-std::numeric_limits::max()); \ + __m256i max_j = _mm256_set1_epi32(0); \ + /* Calculate the offset of transition_weights.*/ \ + int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ + for (int i = 0; i < this->num_; ++i) { \ + /* Initalize the content of alpha variable with related offset.*/ \ + __m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \ + /* Obtain the content of weights from un-aligned address.*/ \ + __m256 w_content = _mm256_loadu_ps(w + trans_offset); \ + __m256 score_v = _mm256_add_ps(alpha_content, w_content); \ + __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \ + /* According to the mask value, update the index of the max_score.*/ \ + /* AVX2 instructions.*/ \ + max_j = _mm256_or_si256( \ + _mm256_andnot_si256((__m256i)mask, max_j), \ + _mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); \ + /* Update the max_score value.*/ \ + max_score = _mm256_max_ps(max_score, score_v); \ + trans_offset += this->num_; \ + } \ + UPDATE_ALPHA(AVX2_FLOAT_BLOCK) \ + } \ + seq_offset += this->num_; \ + } \ + } + +#define INTRIAVX512_FLOAT(block) \ + template <> \ + CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ + int tag_num) \ + : CRFDecodeKernel() { \ + this->num_ = tag_num; \ + this->end_ = this->num_ / AVX512_FLOAT_BLOCK; \ + this->rest_ = this->num_ % AVX512_FLOAT_BLOCK; \ + } \ + template <> \ + void CRFDecodeKernelImpl::Compute( \ + const int seq_len, const float* x, const float* w, float* alpha, \ + int* track) const { \ + INIT_ALPHA(AVX512_FLOAT_BLOCK) \ + /* Use the column-major strategy to get the location of maximum score.*/ \ + int seq_offset = 0; \ + constexpr int state_trans_base_idx = 2; \ + for (int k = 1; k < seq_len; ++k) { \ + int j_offset = 0; \ + for (int j = 0; j <= this->end_; ++j) { \ + /* Initialize the variables of maximum score and location.*/ \ + __m512 max_score = _mm512_set1_ps(-std::numeric_limits::max()); \ + __m512i max_j = _mm512_setzero_si512(); \ + /* Calculate the offset of transition_weights.*/ \ + int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ + for (int i = 0; i < this->num_; ++i) { \ + /* Initalize the content of alpha variable with related offset.*/ \ + __m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i)); \ + /* Obtain the content of weights from un-aligned address.*/ \ + __m512 w_content = _mm512_loadu_ps(w + trans_offset); \ + __m512 score_v = _mm512_add_ps(alpha_content, w_content); \ + __mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); \ + /* AVX512 instructions.*/ \ + max_j = _mm512_mask_set1_epi32(max_j, mask, i); \ + /* Update the max_score value.*/ \ + max_score = _mm512_max_ps(max_score, score_v); \ + trans_offset += this->num_; \ + } \ + /* Update the alpha and track values.*/ \ + __m512 x_content = \ + _mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \ + max_score = _mm512_add_ps(max_score, x_content); \ + _mm512_storeu_ps(alpha_value + seq_offset + this->tag_num_ + j_offset, \ + max_score); \ + _mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \ + this->num_ + j_offset), \ + max_j); \ + /* Calculate the offset of next step*/ \ + j_offset += AVX512_FLOAT_BLOCK; \ + if (j == this->end_ - 1) { \ + if (this->rest_ > 0) { \ + j_offset += last_offset; \ + } else { \ + break; \ + } \ + } \ + } \ + seq_offset += this->num_; \ + } \ + } + +#ifdef __AVX__ +INTRIAVX_FLOAT(kEQ8); +INTRIAVX_FLOAT(kGT8LT16); +INTRIAVX_FLOAT(kEQ16); +INTRIAVX_FLOAT(kGT16); +#endif +#ifdef __AVX2__ +INTRIAVX2_FLOAT(kEQ8); +INTRIAVX2_FLOAT(kGT8LT16); +INTRIAVX2_FLOAT(kEQ16); +INTRIAVX2_FLOAT(kGT16); +#endif +#ifdef __AVX512F__ +INTRIAVX2_FLOAT(kEQ8); +INTRIAVX2_FLOAT(kGT8LT16); +INTRIAVX512_FLOAT(kEQ16); +INTRIAVX512_FLOAT(kGT16); +#endif + +#undef INTRIAVX512_FLOAT +#undef INTRIAVX2_FLOAT +#undef INTRIAVX_FLOAT +#undef INIT_ALPHA +#undef UPDATE_ALPHA + +REGISTER_JITKERNEL(crf_decode, CRFDecodeKernel); + +} // namespace jitkernel +} // namespace math +} // namespace operators +} // namespace paddle