Commit 6a159071 authored by tensor-tang

add vtanh jitcode of size 8

Parent 046374bc
...@@ -168,24 +168,26 @@ void ReluJitCode::generate() {
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float)
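// Sketch for illustration (not part of this commit): each OFFSET_* above
// addresses one AVX_FLOAT_BLOCK-wide row of floats in exp_float_consts below,
// so inserting OFFSET_EXP_TWO shifts every later constant down by exactly one
// row. A hypothetical compile-time sanity check:
static_assert(OFFSET_EXP_TWO - OFFSET_EXP_ONE == AVX_FLOAT_BLOCK * sizeof(float),
              "adjacent OFFSET_* entries are one AVX block apart");
static_assert(OFFSET_SIGMOID_MIN - OFFSET_SIGMOID_MAX ==
                  AVX_FLOAT_BLOCK * sizeof(float),
              "OFFSET_SIGMOID_MIN is the row after the shifted OFFSET_SIGMOID_MAX");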
static const float exp_float_consts[] ALIGN32 = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
...@@ -216,6 +218,7 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
ymm_t ymm_fy = ymm_t(3);
ymm_t ymm_mask = ymm_t(4);
ymm_t ymm_tmp = ymm_t(5);
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enforce
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
...@@ -327,6 +330,40 @@ void VSigmoidJitCode::generate() {
ret();
}
bool VTanhJitCode::init(int d) {
return MayIUse(avx) && d == 8; // only size 8 is supported yet
}
void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
// y = 2 / (1 + e^(-2x)) - 1
// use ymm2, ymm3
reg64_t reg_ptr_global = rax;
ymm_t ymm_tmp = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(ymm_zero, ymm_zero, ymm_zero);
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
vmulps(ymm_src, ymm_src, ymm_tmp);
exp_ymm(ymm_src, ymm_dst);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vsubps(ymm_dst, ymm_dst, ymm_tmp);
pop(reg_ptr_global);
}
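// Sketch for illustration (not part of this commit): a scalar mirror of the
// vtanh_ymm sequence above, to make the register-level steps easier to follow.
// Assumes <cmath> is included; the helper name is hypothetical.
static inline float vtanh_scalar_sketch(float x) {
  float t = -2.0f * x;          // load 2.f, negate it via (0 - 2), then vmulps
  float e = std::exp(t);        // exp_ymm: e^(-2x)
  float d = 2.0f / (1.0f + e);  // vaddps with ONE, then vdivps TWO by the sum
  return d - 1.0f;              // vsubps ONE: y = 2 / (1 + e^(-2x)) - 1
}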
void VTanhJitCode::generate() {
int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]);
vtanh_ymm(ymm_src, ymm_dst);
vmovups(ptr[param2 + offset], ymm_dst);
ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
......
...@@ -149,6 +149,26 @@ class VSigmoidJitCode : public VExpJitCode {
ymm_t ymm_dst = ymm_t(1);
};
class VTanhJitCode : public VExpJitCode {
public:
DECLARE_JIT_CODE(VTanhJitCode);
explicit VTanhJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
// compute tanh with ymm
void vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst);
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_dst = ymm_t(1);
};
} // namespace gen
} // namespace jitkernel
} // namespace math
......
...@@ -132,6 +132,7 @@ template <typename T>
class VTanhKernel : public VActKernel<T> {
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T>
......
...@@ -45,6 +45,7 @@ void VExpRefer(const T* x, T* y, int n) {
template <typename T>
void VSigmoidRefer(const T* x, T* y, int n) {
// y = 1 / (1 + e^-x)
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
...@@ -53,6 +54,18 @@ void VSigmoidRefer(const T* x, T* y, int n) {
}
}
template <typename T>
void VTanhRefer(const T* x, T* y, int n) {
// y = 2 * sigmoid(2x) - 1
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoidRefer(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
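// Sketch for illustration (not part of this commit): VTanhRefer relies on the
// identity tanh(x) = 2 * sigmoid(2x) - 1 = 2 / (1 + e^(-2x)) - 1. A minimal
// test-style check against std::tanh, assuming <cmath> and <cassert> are
// available; the helper name is hypothetical.
template <typename T>
void CheckVTanhReferSketch() {
  const T x[4] = {T(-2), T(-0.5), T(0), T(1.5)};
  T y[4];
  VTanhRefer(x, y, 4);
  for (int i = 0; i < 4; ++i) {
    assert(std::abs(y[i] - std::tanh(x[i])) < static_cast<T>(1e-5));
  }
}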
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VExpMKL(const T* x, T* y, int n);
...@@ -80,6 +93,17 @@ void VSigmoidMKL(const T* x, T* y, int n) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
template <typename T>
void VTanhMKL(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoidMKL(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#endif
/* VExp JitKernel */
...@@ -189,8 +213,63 @@ bool VSigmoidKernelImpl<double>::useMKL(int d) {
}
#endif
/* VTanh JitKernel */
template <typename T>
class VTanhKernelImpl : public VTanhKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // rough code-size estimate, should be refined
jitcode_.reset(new gen::VTanhJitCode(d, sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
// prefer the MKL implementation when applicable; otherwise fall back to the reference one
if (useMKL(d)) {
this->Compute = VTanhMKL<T>;
return;
}
#endif
this->Compute = VTanhRefer<T>;
}
void ComputeDeprecated(const T* x, T* y) const override {
VTanhRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VTanhJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VTanhKernelImpl<float>::useJIT(int d) {
return gen::VTanhJitCode::init(d);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VTanhKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VTanhKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
REGISTER_JITKERNEL(vtanh, VTanhKernel);
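// Sketch for illustration (not part of this commit): typical use of the newly
// registered kernel, following the KernelPool pattern used elsewhere in this
// file and in the updated test below; buffer names are hypothetical.
inline void VTanhUsageSketch(const float* x_data, float* y_data, int d) {
  const auto ker = KernelPool::Instance().Get<VTanhKernel<float>>(d);
  ker->Compute(x_data, y_data, d);
}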
namespace detail {
...@@ -337,156 +416,6 @@ __m256 ExpAVX2(__m256 x) {
#endif
} // namespace detail
#define INTRI_SIGMOID(tmp, min, max, expisa) \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_min_ps(tmp, max); \
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
#undef INTRI_VSIGMOID
/* VTanh JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VTanhKernelImpl : public VTanhKernel<T> {
public:
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
this->num_ = d;
vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
}
void ComputeDeprecated(const T* x, T* y) const override {
const T a = static_cast<T>(2), b = static_cast<T>(-1);
vscal_->Compute(&a, x, y, this->num_);
vsigmoid_->ComputeDeprecated(y, y);
vscal_->Compute(&a, y, y, this->num_);
vaddbias_->Compute(&b, y, y, this->num_);
}
private:
std::shared_ptr<const VScalKernel<T>> vscal_;
std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
};
#define INTRI_VTANH(tmp, expisa) \
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_VTANH(tmp0, expisa); \
INTRI_VTANH(tmp1, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
vscal_ = \
KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
this->rest_); \
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
this->rest_); \
} \
template <> \
void VTanhKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->ComputeDeprecated(y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
this->num_ = d; \
this->rest_ = d % AVX_FLOAT_BLOCK; \
this->end_ = d - this->rest_; \
vscal_ = \
KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
this->rest_); \
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
this->rest_); \
} \
template <> \
void VTanhKernelImpl<float, isa, kGT16>::ComputeDeprecated(const float* x, \
float* y) const { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
x += this->end_; \
y += this->end_; \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->ComputeDeprecated(y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef INTRI_VTANH
REGISTER_JITKERNEL_DEPRECATED(vtanh, VTanhKernel);
#undef JITKERNEL_NEW_ACT_IMPL
} // namespace jitkernel
} // namespace math
} // namespace operators
......
...@@ -322,7 +322,7 @@ TEST(JitKernel, vtanh) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
......