提交 b1516783 编写于 作者: T tensor-tang

enable crf decoding intrinsic code

上级 4cc7707d
......@@ -19,6 +19,10 @@ PaddlePaddle/Paddle/paddle/fluid/
│ ├── ...
│ ├── mkl/
│ │ └── ...
│ ├── mkldnn/
│ │ └── ...
│ ├── intrinsic/
│ │ └── ...
│ └── openblas/
│ └── ...
└── refer/
......
......@@ -6,4 +6,3 @@ set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(crfdecoding, intrinsic)
USE_JITKERNEL_MORE(layernorm, intrinsic)
......@@ -13,7 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -21,118 +21,151 @@ namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mkl {
template <>
void VMul<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z);
}
template <>
void VMul<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdMul(n, x, y, z);
}
template <>
void VAdd<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsAdd(n, x, y, z);
}
template <>
void VAdd<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
template <>
void VScal<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
refer::VScal<float>(a, x, y, n);
namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num) {
const int step_size =
platform::MayIUse(platform::avx512f) ? ZMM_FLOAT_BLOCK : YMM_FLOAT_BLOCK;
const int end = tag_num / step_size;
const int rest = tag_num % step_size;
/* Setup the alpha initial value.*/
int i_offset = 0;
int last_offset = rest - step_size;
for (int i = 0; i <= end; ++i) {
#ifdef __AVX512F__
// Declare the variable for the content of weights, input and alpha values.
__m512 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm512_loadu_ps(w + i_offset);
x_content = _mm512_loadu_ps(x + i_offset);
alpha_content = _mm512_add_ps(w_content, x_content);
// Save the alpha value.
_mm512_storeu_ps(alpha_value + i_offset, alpha_content);
#else
// AVX or AVX2
// weights, input and alpha values.
__m256 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm256_loadu_ps(w + i_offset);
x_content = _mm256_loadu_ps(x + i_offset);
alpha_content = _mm256_add_ps(w_content, x_content);
_mm256_storeu_ps(alpha + i_offset, alpha_content);
#endif
i_offset += step_size;
if (i == end - 1) {
if (rest > 0) {
i_offset += last_offset;
} else {
break;
}
}
}
}
template <>
void VScal<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
refer::VScal<double>(a, x, y, n);
// Use the column-major strategy to get the location of maximum score.
int seq_offset = 0;
constexpr int state_trans_base_idx = 2;
for (int k = 1; k < seq_len; ++k) {
int j_offset = 0;
for (int j = 0; j <= end; ++j) {
/* Initialize the variables of maximum score and location.*/
#ifdef __AVX512F__
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max());
__m512i max_j = _mm512_setzero_si512();
#else
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max());
__m256i max_j = _mm256_set1_epi32(0);
#endif
/* Calculate the offset of transition_weights.*/
int trans_offset = state_trans_base_idx * tag_num + j_offset;
for (int i = 0; i < tag_num; ++i) {
/* Initalize the content of alpha variable with related offset.*/
#ifdef __AVX512F__
__m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i));
/* Obtain the content of weights from un-aligned address.*/
__m512 w_content = _mm512_loadu_ps(w + trans_offset);
__m512 score_v = _mm512_add_ps(alpha_content, w_content);
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS);
/* AVX512 instructions.*/
max_j = _mm512_mask_set1_epi32(max_j, mask, i);
/* Update the max_score value.*/
max_score = _mm512_max_ps(max_score, score_v);
#else
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i);
/* Obtain the content of weights from un-aligned address.*/
__m256 w_content = _mm256_loadu_ps(w + trans_offset);
__m256 score_v = _mm256_add_ps(alpha_content, w_content);
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS);
/* According to the mask value, update the index of the max_score.*/
#ifdef __AVX2__
max_j = _mm256_or_si256(
_mm256_andnot_si256((__m256i)mask, max_j),
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i)));
#else
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0);
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1);
__m128i lo_mask =
_mm256_extractf128_si256(*(__m256i*)&mask, 0); // NOLINT
__m128i hi_mask =
_mm256_extractf128_si256(*(__m256i*)&mask, 1); // NOLINT
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j);
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j);
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i));
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i));
lo_max_j = _mm_or_si128(lo_mask, lo_max_j);
hi_max_j = _mm_or_si128(hi_mask, hi_max_j);
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0);
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1);
#endif
/* Update the max_score value.*/
max_score = _mm256_max_ps(max_score, score_v);
#endif
trans_offset += tag_num;
}
/* Update the alpha and track values. */
#ifdef __AVX512F__
__m512 x_content =
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset);
max_score = _mm512_add_ps(max_score, x_content);
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset +
this->num_ + j_offset),
max_j);
#else
__m256 x_content = _mm256_loadu_ps(x + seq_offset + tag_num + j_offset);
max_score = _mm256_add_ps(max_score, x_content);
_mm256_storeu_ps(alpha + seq_offset + tag_num + j_offset, max_score);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(track + seq_offset + tag_num + j_offset),
max_j);
#endif
/* Calculate the offset of next step*/
j_offset += step_size;
if (j == end - 1) {
if (rest > 0) {
j_offset += last_offset;
} else {
break;
}
}
}
seq_offset += tag_num;
}
}
template <>
void VExp<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y);
}
template <>
void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::UseMe(int d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
bool CRFDecodingKernel::UseMe(int d) const {
return platform::MayIUse(platform::avx);
}
template <>
bool VAddKernel<float>::UseMe(int d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VScalKernel<float>::UseMe(int d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VExpKernel<float>::UseMe(int d) const {
return d > 7;
}
template <>
bool VSigmoidKernel<float>::UseMe(int d) const {
return d > 7;
}
template <>
bool VTanhKernel<float>::UseMe(int d) const {
return d > 7;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(int d) const { \
return true; \
}
AWALYS_USE_ME_WITH_DOUBLE(VMul);
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
AWALYS_USE_ME_WITH_DOUBLE(VScal);
AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
#undef AWALYS_USE_ME_WITH_DOUBLE
} // namespace mkl
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace mkl = paddle::operators::jit::more::mkl;
#define REGISTER_MKL_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(vmul, VMul);
REGISTER_MKL_KERNEL(vadd, VAdd);
REGISTER_MKL_KERNEL(vscal, VScal);
REGISTER_MKL_KERNEL(vexp, VExp);
REGISTER_MKL_KERNEL(vsigmoid, VSigmoid);
REGISTER_MKL_KERNEL(vtanh, VTanh);
namespace intrinsic = paddle::operators::jit::more::intrinsic;
#undef REGISTER_MKL_KERNEL
REGISTER_JITKERNEL_MORE(crfdecoding, intrinsic, intrinsic::CRFDecodingKernel);
......@@ -21,68 +21,18 @@ namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mkl {
namespace intrinsic {
template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num);
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n);
class CRFDecodingKernel : public KernelImpl<CRFDecodingTuples<float>> {
public:
CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe(typename CRFDecodingTuples<float>::attr_type) const override;
};
template <typename T>
void VScal(const T* a, const T* x, T* y, int n);
template <typename T>
void VExp(const T* x, T* y, int n);
template <typename T>
void VSigmoid(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
VExp(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
template <typename T>
void VTanh(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoid(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelImpl<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(typename tuples<T>::attr_type) const override; \
}
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
// AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples);
// XYN
DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
#undef DECLARE_MKL_KERNEL
} // namespace mkl
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册