From 1f00723fa379503367abd96ad8f6567fa31c4e86 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 16 Nov 2018 07:40:41 +0000 Subject: [PATCH] exp, sigmoid, tanh jitcode support more size test=develop --- paddle/fluid/operators/math/cpu_vec.h | 18 +++--- paddle/fluid/operators/math/jit_code.cc | 57 ++++++++++--------- paddle/fluid/operators/math/jit_kernel.h | 7 +-- .../fluid/operators/math/jit_kernel_blas.cc | 12 ++-- .../operators/math/jit_kernel_crf_decode.cc | 24 ++++---- paddle/fluid/operators/math/jit_kernel_exp.cc | 6 +- .../fluid/operators/math/jit_kernel_macro.h | 22 +++---- 7 files changed, 74 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 0aed253c80f..7d81aee5969 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -33,11 +33,11 @@ namespace math { #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 -#define AVX_FLOAT_BLOCK 8 +#define YMM_FLOAT_BLOCK 8 #define AVX_DOUBLE_BLOCK 4 -#define AVX2_FLOAT_BLOCK 8 +#define YMM_FLOAT_BLOCK 8 #define AVX2_DOUBLE_BLOCK 4 -#define AVX512_FLOAT_BLOCK 16 +#define ZMM_FLOAT_BLOCK 16 #define AVX512_DOUBLE_BLOCK 8 template @@ -88,7 +88,7 @@ template <> inline void vec_scal(const int n, const float a, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_scal(n, a, x, y); return; @@ -142,7 +142,7 @@ template <> inline void vec_bias_sub(const int n, const float a, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_bias_sub(n, a, x, y); return; @@ -200,7 +200,7 @@ inline void vec_cross(const int n, const float* x, const float* y, const float* z, float* out) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_cross(n, x, y, z, out); return; @@ -257,7 +257,7 @@ template <> inline void vec_add_bias(const int n, const float a, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_add_bias(n, a, x, y); return; @@ -326,7 +326,7 @@ template <> inline void vec_sigmoid(const int n, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_sigmoid(n, x, y); return; @@ -415,7 +415,7 @@ template <> inline void vec_relu(const int n, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block * 4) { vec_relu(n, x, y); return; diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 15976902759..e3b600d4427 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -41,7 +41,7 @@ void VXXJitCode::generate() { } else if (scalar_index_ == 2) { vbroadcastss(ymm_src2, ptr[param2]); } - for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { if (scalar_index_ != 1) { vmovups(ymm_src1, ptr[param1 + offset]); } @@ -57,9 +57,9 @@ void VXXJitCode::generate() { vmaxps(ymm_dst, ymm_zero, ymm_dst); } vmovups(ptr[param3 + offset], ymm_dst); - offset += sizeof(float) * AVX_FLOAT_BLOCK; + offset += sizeof(float) * YMM_FLOAT_BLOCK; } - int rest = num_ % AVX_FLOAT_BLOCK; + int rest = num_ % YMM_FLOAT_BLOCK; if (rest >= 4) { if (scalar_index_ != 1) { vmovups(xmm_src1, ptr[param1 + offset]); @@ -133,23 +133,23 @@ void VXXJitCode::generate() { #define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val -#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float) -#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) static const float exp_float_consts[] ALIGN32 = { REPEAT_8TIMES(1.f), @@ -177,9 +177,12 @@ bool VActJitCode::init(int d, operand_type type) { bool ok = MayIUse(avx); if (type == operand_type::relu) { return ok; + } else if (type == operand_type::exp) { + // exp is slower than mkl when d >= 256 + return ok && d % 8 == 0 && d < 256; } else { // TODO(TJ): support more - return ok && d == 8; // only 8 yet + return ok && d % 8 == 0; } } @@ -224,7 +227,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); vmulps(ymm_dst, ymm_src, ymm_tmp); for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; - i += (AVX_FLOAT_BLOCK * sizeof(float))) { + i += (YMM_FLOAT_BLOCK * sizeof(float))) { vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4 vaddps(ymm_dst, ymm_dst, ymm_tmp); vmulps(ymm_dst, ymm_dst, ymm_src); @@ -249,7 +252,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, reg64_t reg_ptr_tmp = reg_ptr_global; mov(reg_ptr_tmp, reinterpret_cast(g_tmp_mem)); vmovdqa(ptr[reg_ptr_tmp], ymm_int); - vmovdqa(ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp); + vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp); vpaddd(xtmp1, xtmp1, xtmp2); vpslld(xtmp1, xtmp1, 23); vmovdqa(ptr[reg_ptr_tmp], xtmp1); @@ -257,7 +260,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]); vmovdqa(xtmp2, ptr[reg_ptr_tmp + - (AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); + (YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); vpaddd(xtmp1, xtmp1, xtmp2); vpslld(xtmp1, xtmp1, 23); vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1); @@ -317,7 +320,7 @@ void VActJitCode::generate() { vxorps(ymm_zero, ymm_zero, ymm_zero); } int offset = 0; - for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { vmovups(ymm_src, ptr[param1 + offset]); switch (type_) { case operand_type::relu: @@ -338,14 +341,14 @@ void VActJitCode::generate() { break; } vmovups(ptr[param2 + offset], ymm_dst); - offset += sizeof(float) * AVX_FLOAT_BLOCK; + offset += sizeof(float) * YMM_FLOAT_BLOCK; } if (type_ != operand_type::relu) { // TODO(TJ): remove me ret(); return; } - int rest = num_ % AVX_FLOAT_BLOCK; + int rest = num_ % YMM_FLOAT_BLOCK; if (rest >= 4) { vmovups(xmm_src, ptr[param1 + offset]); vmaxps(xmm_dst, xmm_zero, xmm_src); diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index b023ef096ad..4d8d3cd79a1 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -29,10 +29,9 @@ namespace jitkernel { #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 #define EXP_MAX_INPUT 40.0 -// TODO(TJ): change AVX_FLOAT_BLOCK to YMM_FLOAT_BLOCK -#define AVX_FLOAT_BLOCK 8 -#define AVX2_FLOAT_BLOCK 8 -#define AVX512_FLOAT_BLOCK 16 +#define XMM_FLOAT_BLOCK 4 +#define YMM_FLOAT_BLOCK 8 +#define ZMM_FLOAT_BLOCK 16 typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index e9e7eec445c..36a50f20434 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -133,7 +133,7 @@ class VMulKernelImpl : public VMulKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { // roughly estimate the size of code - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, sz > 4096 ? sz : 4096)); this->Compute = @@ -184,7 +184,7 @@ class VAddKernelImpl : public VAddKernel { explicit VAddKernelImpl(int d) : VAddKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, sz > 4096 ? sz : 4096)); this->Compute = @@ -234,7 +234,7 @@ class VAddReluKernelImpl : public VAddReluKernel { explicit VAddReluKernelImpl(int d) : VAddReluKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, sz > 4096 ? sz : 4096)); this->Compute = @@ -266,7 +266,7 @@ class VScalKernelImpl : public VScalKernel { explicit VScalKernelImpl(int d) : VScalKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, sz > 4096 ? sz : 4096)); this->Compute = @@ -315,7 +315,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel { explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, sz > 4096 ? sz : 4096)); this->Compute = @@ -349,7 +349,7 @@ class VReluKernelImpl : public VReluKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 /* init size */ + - d / AVX_FLOAT_BLOCK * 4 /* instructions */ * + d / YMM_FLOAT_BLOCK * 4 /* instructions */ * 8 /* average bytes for each instruction */; jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu, sz > 4096 ? sz : 4096)); diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc index a4861c347e4..4d26b819482 100644 --- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc +++ b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc @@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ - this->end_ = this->num_ / AVX_FLOAT_BLOCK; \ - this->rest_ = this->num_ % AVX_FLOAT_BLOCK; \ + this->end_ = this->num_ / YMM_FLOAT_BLOCK; \ + this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ } \ template <> \ void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ - INIT_ALPHA(AVX_FLOAT_BLOCK) \ + INIT_ALPHA(YMM_FLOAT_BLOCK) \ /* Use the column-major strategy to get the location of maximum score.*/ \ int seq_offset = 0; \ constexpr int state_trans_base_idx = 2; \ @@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { max_score = _mm256_max_ps(max_score, score_v); \ trans_offset += this->num_; \ } \ - UPDATE_ALPHA(AVX_FLOAT_BLOCK) \ + UPDATE_ALPHA(YMM_FLOAT_BLOCK) \ } \ seq_offset += this->num_; \ } \ @@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { CRFDecodeKernelImpl::CRFDecodeKernelImpl(int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ - this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \ - this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \ + this->end_ = this->num_ / YMM_FLOAT_BLOCK; \ + this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ } \ template <> \ void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ - INIT_ALPHA(AVX2_FLOAT_BLOCK) \ + INIT_ALPHA(YMM_FLOAT_BLOCK) \ /* Use the column-major strategy to get the location of maximum score.*/ \ int seq_offset = 0; \ constexpr int state_trans_base_idx = 2; \ @@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { max_score = _mm256_max_ps(max_score, score_v); \ trans_offset += this->num_; \ } \ - UPDATE_ALPHA(AVX2_FLOAT_BLOCK) \ + UPDATE_ALPHA(YMM_FLOAT_BLOCK) \ } \ seq_offset += this->num_; \ } \ @@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ - this->end_ = this->num_ / AVX512_FLOAT_BLOCK; \ - this->rest_ = this->num_ % AVX512_FLOAT_BLOCK; \ + this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \ + this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \ } \ template <> \ void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ - INIT_ALPHA(AVX512_FLOAT_BLOCK) \ + INIT_ALPHA(ZMM_FLOAT_BLOCK) \ /* Use the column-major strategy to get the location of maximum score.*/ \ int seq_offset = 0; \ constexpr int state_trans_base_idx = 2; \ @@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { this->num_ + j_offset), \ max_j); \ /* Calculate the offset of next step*/ \ - j_offset += AVX512_FLOAT_BLOCK; \ + j_offset += ZMM_FLOAT_BLOCK; \ if (j == this->end_ - 1) { \ if (this->rest_ > 0) { \ j_offset += last_offset; \ diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 0e2cdad4700..f2cb8fb74e5 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -116,7 +116,7 @@ class VExpKernelImpl : public VExpKernel { explicit VExpKernelImpl(int d) : VExpKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8; jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); @@ -167,7 +167,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8; jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); @@ -219,7 +219,7 @@ class VTanhKernelImpl : public VTanhKernel { explicit VTanhKernelImpl(int d) : VTanhKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8; jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h index e8bbc0cae57..8acf60cfbfd 100644 --- a/paddle/fluid/operators/math/jit_kernel_macro.h +++ b/paddle/fluid/operators/math/jit_kernel_macro.h @@ -94,17 +94,17 @@ namespace jitkernel { namespace jit = platform::jit; // TODO(TJ): below defines are deprecated, would be remove recently -#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ - if (d < AVX_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kLT8); \ - } else if (d == AVX_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kEQ8); \ - } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kGT8LT16); \ - } else if (d == AVX512_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kEQ16); \ - } else { \ - macro_(ker, dtype, isa, kGT16); \ +#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ + if (d < YMM_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kLT8); \ + } else if (d == YMM_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kEQ8); \ + } else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kGT8LT16); \ + } else if (d == ZMM_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kEQ16); \ + } else { \ + macro_(ker, dtype, isa, kGT16); \ } #define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ -- GitLab