提交 046374bc 编写于 作者: T tensor-tang

add vsigmoid jitcode of size 8

上级 ee2a7f1b
......@@ -152,10 +152,6 @@ void ReluJitCode::generate() {
ret();
}
bool VExpJitCode::init(int d) {
return MayIUse(avx) && d == 8; // only 8 yet
}
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
......@@ -171,6 +167,7 @@ bool VExpJitCode::init(int d) {
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
......@@ -183,24 +180,43 @@ bool VExpJitCode::init(int d) {
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float)
static const float exp_float_consts[] ALIGN32 = {
REPEAT_8TIMES(1.f), REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG), REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF), REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2), REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1), REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3), REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5)};
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF),
REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2),
REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1),
REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3),
REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5),
REPEAT_8TIMES(EXP_MAX_INPUT),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static int g_tmp_mem[16] ALIGN32 = {0};
void VExpJitCode::generate() {
// in: ymm0, out: ymm1
// use ymm 0~5, rax
int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]);
bool VExpJitCode::init(int d) {
return MayIUse(avx) && d == 8; // only 8 yet
}
void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
// use reg rax and ymm 2~5
reg64_t reg_ptr_global = rax;
ymm_t ymm_fx = ymm_t(2);
ymm_t ymm_fy = ymm_t(3);
ymm_t ymm_mask = ymm_t(4);
ymm_t ymm_tmp = ymm_t(5);
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(ymm_src, ymm_src, ymm_tmp);
......@@ -269,8 +285,45 @@ void VExpJitCode::generate() {
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
}
vmulps(ymm_dst, ymm_dst, ymm_int);
pop(reg_ptr_global);
}
void VExpJitCode::generate() {
int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]);
exp_ymm(ymm_src, ymm_dst);
vmovups(ptr[param2 + offset], ymm_dst);
ret();
}
bool VSigmoidJitCode::init(int d) {
return MayIUse(avx) && d == 8; // only 8 yet
}
void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
// use ymm2
reg64_t reg_ptr_global = rax;
ymm_t ymm_tmp = ymm_t(2);
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
vminps(ymm_src, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
vmaxps(ymm_src, ymm_src, ymm_tmp);
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
vsubps(ymm_src, ymm_tmp, ymm_src);
exp_ymm(ymm_src, ymm_dst);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
pop(reg_ptr_global);
}
void VSigmoidJitCode::generate() {
int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]);
sigmoid_ymm(ymm_src, ymm_dst);
vmovups(ptr[param2 + offset], ymm_dst);
ret();
}
......
......@@ -117,18 +117,36 @@ class VExpJitCode : public JitCode {
static bool init(int d);
void generate() override;
protected:
// compute exp with ymm
void exp_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_dst = ymm_t(1);
};
class VSigmoidJitCode : public VExpJitCode {
public:
DECLARE_JIT_CODE(VSigmoidJitCode);
explicit VSigmoidJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
reg64_t reg_ptr_global = rax;
// compute sigmoid with ymm
void sigmoid_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_fx = ymm_t(2);
ymm_t ymm_fy = ymm_t(3);
ymm_t ymm_mask = ymm_t(4);
ymm_t ymm_tmp = ymm_t(5);
};
} // namespace gen
......
......@@ -29,6 +29,7 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
// TODO(TJ): change AVX_FLOAT_BLOCK to YMM_FLOAT_BLOCK
#define AVX_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16
......@@ -124,6 +125,7 @@ template <typename T>
class VSigmoidKernel : public VActKernel<T> {
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T>
......
......@@ -43,6 +43,16 @@ void VExpRefer(const T* x, T* y, int n) {
}
}
template <typename T>
void VSigmoidRefer(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VExpMKL(const T* x, T* y, int n);
......@@ -56,6 +66,20 @@ template <>
void VExpMKL<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
template <typename T>
void VSigmoidMKL(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
VExpMKL(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
#endif
/* VExp JitKernel */
......@@ -108,9 +132,65 @@ template <>
bool VExpKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VSigmoid JitKernel */
template <typename T>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
jitcode_.reset(new gen::VSigmoidJitCode(d, sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if (useMKL(d)) {
this->Compute = VSigmoidMKL<T>;
return;
}
#endif
this->Compute = VSigmoidRefer<T>;
}
void ComputeDeprecated(const T* x, T* y) const override {
VSigmoidRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VSigmoidJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VSigmoidKernelImpl<float>::useJIT(int d) {
return gen::VSigmoidJitCode::init(d);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VSigmoidKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VSigmoidKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
namespace detail {
......@@ -258,31 +338,6 @@ __m256 ExpAVX2(__m256 x) {
} // namespace detail
/* VSigmoid JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public:
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
this->num_ = d;
vexp_ = KernelPool::Instance().template Get<VExpKernel<T>>(d);
}
void ComputeDeprecated(const T* x, T* y) const override {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < this->num_; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
vexp_->ComputeDeprecated(y, y);
for (int i = 0; i < this->num_; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
private:
std::shared_ptr<const VExpKernel<T>> vexp_;
};
#define INTRI_SIGMOID(tmp, min, max, expisa) \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_min_ps(tmp, max); \
......@@ -290,120 +345,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ8>::ComputeDeprecated( \
const float* x, float* y) const { \
/* TODO(TJ): try to use static const*/ \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_SIGMOID(tmp0, min, max, expisa); \
INTRI_SIGMOID(tmp1, min, max, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
vexp_ = \
KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
} \
template <> \
void VSigmoidKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
y[i] = 0.f - y[i]; \
} \
vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 1.f / (1.f + y[i]); \
} \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
this->num_ = d; \
this->rest_ = d % AVX_FLOAT_BLOCK; \
this->end_ = d - this->rest_; \
vexp_ = \
KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
} \
template <> \
void VSigmoidKernelImpl<float, isa, kGT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
y[i] = 0.f - y[i]; \
} \
vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 1.f / (1.f + y[i]); \
} \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx2 at gt8lt16 and gt16
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef INTRI_VSIGMOID
REGISTER_JITKERNEL_DEPRECATED(vsigmoid, VSigmoidKernel);
/* VTanh JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VTanhKernelImpl : public VTanhKernel<T> {
......
......@@ -223,7 +223,7 @@ void vsigmoid_better(
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 0.f - y[i];
}
vexp->ComputeDeprecated(y, y);
vexp->Compute(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = 1.f / (1.f + y[i]);
}
......@@ -254,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
......@@ -288,7 +288,7 @@ void vtanh_better(
const int n, const float* x, float* y) {
const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n);
vsigmoid->ComputeDeprecated(y, y);
vsigmoid->Compute(y, y, n);
vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册