提交 1f00723f 编写于 作者: T tensor-tang

exp, sigmoid, and tanh jitcode now support more sizes

test=develop
上级 8cda7b3d
...@@ -33,11 +33,11 @@ namespace math { ...@@ -33,11 +33,11 @@ namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define AVX_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX_DOUBLE_BLOCK 4 #define AVX_DOUBLE_BLOCK 4
#define AVX2_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX2_DOUBLE_BLOCK 4 #define AVX2_DOUBLE_BLOCK 4
#define AVX512_FLOAT_BLOCK 16 #define ZMM_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8 #define AVX512_DOUBLE_BLOCK 8
template <typename T> template <typename T>
...@@ -88,7 +88,7 @@ template <> ...@@ -88,7 +88,7 @@ template <>
inline void vec_scal<float, platform::jit::avx>(const int n, const float a, inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_scal<float, platform::jit::isa_any>(n, a, x, y); vec_scal<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -142,7 +142,7 @@ template <> ...@@ -142,7 +142,7 @@ template <>
inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a, inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y); vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x, ...@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
const float* y, const float* z, const float* y, const float* z,
float* out) { float* out) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_cross<float, platform::jit::isa_any>(n, x, y, z, out); vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
return; return;
...@@ -257,7 +257,7 @@ template <> ...@@ -257,7 +257,7 @@ template <>
inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a, inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_add_bias<float, platform::jit::isa_any>(n, a, x, y); vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -326,7 +326,7 @@ template <> ...@@ -326,7 +326,7 @@ template <>
inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x, inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
float* y) { float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_sigmoid<float, platform::jit::isa_any>(n, x, y); vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
return; return;
...@@ -415,7 +415,7 @@ template <> ...@@ -415,7 +415,7 @@ template <>
inline void vec_relu<float, platform::jit::avx>(const int n, const float* x, inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
float* y) { float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block * 4) { if (n < block * 4) {
vec_relu<float, platform::jit::isa_any>(n, x, y); vec_relu<float, platform::jit::isa_any>(n, x, y);
return; return;
......
...@@ -41,7 +41,7 @@ void VXXJitCode::generate() { ...@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
} else if (scalar_index_ == 2) { } else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]); vbroadcastss(ymm_src2, ptr[param2]);
} }
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]); vmovups(ymm_src1, ptr[param1 + offset]);
} }
...@@ -57,9 +57,9 @@ void VXXJitCode::generate() { ...@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
vmaxps(ymm_dst, ymm_zero, ymm_dst); vmaxps(ymm_dst, ymm_zero, ymm_dst);
} }
vmovups(ptr[param3 + offset], ymm_dst); vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
int rest = num_ % AVX_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
...@@ -133,23 +133,23 @@ void VXXJitCode::generate() { ...@@ -133,23 +133,23 @@ void VXXJitCode::generate() {
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val #define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
static const float exp_float_consts[] ALIGN32 = { static const float exp_float_consts[] ALIGN32 = {
REPEAT_8TIMES(1.f), REPEAT_8TIMES(1.f),
...@@ -177,9 +177,12 @@ bool VActJitCode::init(int d, operand_type type) { ...@@ -177,9 +177,12 @@ bool VActJitCode::init(int d, operand_type type) {
bool ok = MayIUse(avx); bool ok = MayIUse(avx);
if (type == operand_type::relu) { if (type == operand_type::relu) {
return ok; return ok;
} else if (type == operand_type::exp) {
// exp is slower than mkl when d >= 256
return ok && d % 8 == 0 && d < 256;
} else { } else {
// TODO(TJ): support more // TODO(TJ): support more
return ok && d == 8; // only 8 yet return ok && d % 8 == 0;
} }
} }
...@@ -224,7 +227,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, ...@@ -224,7 +227,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(ymm_dst, ymm_src, ymm_tmp); vmulps(ymm_dst, ymm_src, ymm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (AVX_FLOAT_BLOCK * sizeof(float))) { i += (YMM_FLOAT_BLOCK * sizeof(float))) {
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4 vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(ymm_dst, ymm_dst, ymm_tmp); vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_src); vmulps(ymm_dst, ymm_dst, ymm_src);
...@@ -249,7 +252,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, ...@@ -249,7 +252,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
reg64_t reg_ptr_tmp = reg_ptr_global; reg64_t reg_ptr_tmp = reg_ptr_global;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem)); mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
vmovdqa(ptr[reg_ptr_tmp], ymm_int); vmovdqa(ptr[reg_ptr_tmp], ymm_int);
vmovdqa(ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp); vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
vpaddd(xtmp1, xtmp1, xtmp2); vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23); vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_tmp], xtmp1); vmovdqa(ptr[reg_ptr_tmp], xtmp1);
...@@ -257,7 +260,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, ...@@ -257,7 +260,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]); vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
vmovdqa(xtmp2, vmovdqa(xtmp2,
ptr[reg_ptr_tmp + ptr[reg_ptr_tmp +
(AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); (YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
vpaddd(xtmp1, xtmp1, xtmp2); vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23); vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1); vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
...@@ -317,7 +320,7 @@ void VActJitCode::generate() { ...@@ -317,7 +320,7 @@ void VActJitCode::generate() {
vxorps(ymm_zero, ymm_zero, ymm_zero); vxorps(ymm_zero, ymm_zero, ymm_zero);
} }
int offset = 0; int offset = 0;
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]); vmovups(ymm_src, ptr[param1 + offset]);
switch (type_) { switch (type_) {
case operand_type::relu: case operand_type::relu:
...@@ -338,14 +341,14 @@ void VActJitCode::generate() { ...@@ -338,14 +341,14 @@ void VActJitCode::generate() {
break; break;
} }
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
if (type_ != operand_type::relu) { if (type_ != operand_type::relu) {
// TODO(TJ): remove me // TODO(TJ): remove me
ret(); ret();
return; return;
} }
int rest = num_ % AVX_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]); vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src); vmaxps(xmm_dst, xmm_zero, xmm_src);
......
...@@ -29,10 +29,9 @@ namespace jitkernel { ...@@ -29,10 +29,9 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
// TODO(TJ): change AVX_FLOAT_BLOCK to YMM_FLOAT_BLOCK #define XMM_FLOAT_BLOCK 4
#define AVX_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8 #define ZMM_FLOAT_BLOCK 16
#define AVX512_FLOAT_BLOCK 16
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
......
...@@ -133,7 +133,7 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -133,7 +133,7 @@ class VMulKernelImpl : public VMulKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
// roughly estimate the size of code // roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -184,7 +184,7 @@ class VAddKernelImpl : public VAddKernel<T> { ...@@ -184,7 +184,7 @@ class VAddKernelImpl : public VAddKernel<T> {
explicit VAddKernelImpl(int d) : VAddKernel<T>() { explicit VAddKernelImpl(int d) : VAddKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -234,7 +234,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> { ...@@ -234,7 +234,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -266,7 +266,7 @@ class VScalKernelImpl : public VScalKernel<T> { ...@@ -266,7 +266,7 @@ class VScalKernelImpl : public VScalKernel<T> {
explicit VScalKernelImpl(int d) : VScalKernel<T>() { explicit VScalKernelImpl(int d) : VScalKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -315,7 +315,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> { ...@@ -315,7 +315,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -349,7 +349,7 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -349,7 +349,7 @@ class VReluKernelImpl : public VReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 /* init size */ + size_t sz = 96 /* init size */ +
d / AVX_FLOAT_BLOCK * 4 /* instructions */ * d / YMM_FLOAT_BLOCK * 4 /* instructions */ *
8 /* average bytes for each instruction */; 8 /* average bytes for each instruction */;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu, jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
......
...@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \ int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX_FLOAT_BLOCK; \ this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX_FLOAT_BLOCK; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \ void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX_FLOAT_BLOCK) \ INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \ max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \ trans_offset += this->num_; \
} \ } \
UPDATE_ALPHA(AVX_FLOAT_BLOCK) \ UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \ } \
seq_offset += this->num_; \ seq_offset += this->num_; \
} \ } \
...@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \ CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \ this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \ void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX2_FLOAT_BLOCK) \ INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \ max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \ trans_offset += this->num_; \
} \ } \
UPDATE_ALPHA(AVX2_FLOAT_BLOCK) \ UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \ } \
seq_offset += this->num_; \ seq_offset += this->num_; \
} \ } \
...@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \ int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX512_FLOAT_BLOCK; \ this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX512_FLOAT_BLOCK; \ this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \ void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX512_FLOAT_BLOCK) \ INIT_ALPHA(ZMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->num_ + j_offset), \ this->num_ + j_offset), \
max_j); \ max_j); \
/* Calculate the offset of next step*/ \ /* Calculate the offset of next step*/ \
j_offset += AVX512_FLOAT_BLOCK; \ j_offset += ZMM_FLOAT_BLOCK; \
if (j == this->end_ - 1) { \ if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \ if (this->rest_ > 0) { \
j_offset += last_offset; \ j_offset += last_offset; \
......
...@@ -116,7 +116,7 @@ class VExpKernelImpl : public VExpKernel<T> { ...@@ -116,7 +116,7 @@ class VExpKernelImpl : public VExpKernel<T> {
explicit VExpKernelImpl(int d) : VExpKernel<T>() { explicit VExpKernelImpl(int d) : VExpKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp, jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
...@@ -167,7 +167,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -167,7 +167,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() { explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid, jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
...@@ -219,7 +219,7 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -219,7 +219,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() { explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh, jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
......
...@@ -95,13 +95,13 @@ namespace jitkernel { ...@@ -95,13 +95,13 @@ namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
// TODO(TJ): below defines are deprecated, would be remove recently // TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ #define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < AVX_FLOAT_BLOCK) { \ if (d < YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kLT8); \ macro_(ker, dtype, isa, kLT8); \
} else if (d == AVX_FLOAT_BLOCK) { \ } else if (d == YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ8); \ macro_(ker, dtype, isa, kEQ8); \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ } else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \ macro_(ker, dtype, isa, kGT8LT16); \
} else if (d == AVX512_FLOAT_BLOCK) { \ } else if (d == ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \ macro_(ker, dtype, isa, kEQ16); \
} else { \ } else { \
macro_(ker, dtype, isa, kGT16); \ macro_(ker, dtype, isa, kGT16); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册