提交 1e06a32a 编写于 作者: T tensor-tang

add vexp jitcode of size 8

test=develop
上级 23544096
......@@ -151,6 +151,132 @@ void ReluJitCode::generate() {
}
ret();
}
bool VExpJitCode::init(int d) {
return MayIUse(avx) && d == 8; // only 8 yet
}
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
static const float exp_float_consts[] ALIGN32 = {
REPEAT_8TIMES(1.f), REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG), REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF), REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2), REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1), REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3), REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5)};
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static int g_tmp_mem[16] ALIGN32 = {0};
void VExpJitCode::generate() {
preCode();
// push some?
// in: ymm0, out: ymm1
// use ymm 0~5 (and ymm 14~15 if avx only)
int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(ymm_src, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
vmaxps(ymm_src, ymm_src, ymm_tmp);
// express exp(x) as exp(g + n*log(2))
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
vmulps(ymm_fx, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
vaddps(ymm_fx, ymm_fx, ymm_tmp);
vroundps(ymm_fy, ymm_fx, 0x01);
// if greater, substract 1
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vandps(ymm_mask, ymm_mask, ymm_tmp);
vsubps(ymm_fx, ymm_fy, ymm_mask);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
vmulps(ymm_fy, ymm_fx, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask
vsubps(ymm_src, ymm_src, ymm_fy);
vsubps(ymm_src, ymm_src, ymm_z);
vmulps(ymm_z, ymm_src, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(ymm_dst, ymm_src, ymm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (AVX_FLOAT_BLOCK * sizeof(float))) {
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_src);
}
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_z);
vaddps(ymm_dst, ymm_dst, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
// build 2^n
ymm_t ymm_int = ymm_fx;
vcvttps2dq(ymm_int, ymm_fx);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
if (MayIUse(avx2)) {
vpaddd(ymm_int, ymm_int, ymm_tmp);
vpslld(ymm_int, ymm_int, 23);
} else if (MayIUse(avx)) {
// use ymm_int, ymm_tmp and reg_ptr_global
xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int
xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp
mov(reg_ptr_global, reinterpret_cast<size_t>(g_tmp_mem));
vmovdqa(ptr[reg_ptr_global], ymm_int);
vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_global], xtmp1);
// next 128bits
vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]);
vmovdqa(xtmp2,
ptr[reg_ptr_global +
(AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
// load out
vmovdqa(ymm_int, ptr[reg_ptr_global]);
}
vmulps(ymm_dst, ymm_dst, ymm_int);
vmovups(ptr[param2 + offset], ymm_dst);
// ret();
postCode();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -108,6 +108,30 @@ class ReluJitCode : public JitCode {
ymm_t ymm_dst = ymm_t(1);
};
class VExpJitCode : public JitCode {
public:
DECLARE_JIT_CODE(VExpJitCode);
explicit VExpJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t reg_ptr_global = rax;
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_fx = ymm_t(2);
ymm_t ymm_fy = ymm_t(3);
ymm_t ymm_mask = ymm_t(4);
ymm_t ymm_z = ymm_t(4);
ymm_t ymm_tmp = ymm_t(5);
};
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -117,6 +117,7 @@ template <typename T>
class VExpKernel : public VActKernel<T> {
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T>
......
......@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle {
namespace operators {
namespace math {
......@@ -128,18 +124,11 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
#endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */
template <typename T>
class VMulKernelImpl : public VMulKernel<T> {
public:
DECLARE_STATIC_FUNC;
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VMulKernelImpl(int d) : VMulKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
......@@ -191,7 +180,7 @@ bool VMulKernelImpl<double>::useMKL(int d) {
template <typename T>
class VAddKernelImpl : public VAddKernel<T> {
public:
DECLARE_STATIC_FUNC;
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddKernelImpl(int d) : VAddKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
......@@ -241,7 +230,7 @@ bool VAddKernelImpl<double>::useMKL(int d) {
template <typename T>
class VAddReluKernelImpl : public VAddReluKernel<T> {
public:
DECLARE_STATIC_FUNC;
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
......@@ -273,7 +262,7 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
template <typename T>
class VScalKernelImpl : public VScalKernel<T> {
public:
DECLARE_STATIC_FUNC;
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VScalKernelImpl(int d) : VScalKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
......@@ -322,7 +311,7 @@ bool VScalKernelImpl<double>::useMKL(int d) {
template <typename T>
class VAddBiasKernelImpl : public VAddBiasKernel<T> {
public:
DECLARE_STATIC_FUNC;
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
......@@ -355,14 +344,14 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
template <typename T>
class VReluKernelImpl : public VReluKernel<T> {
public:
DECLARE_STATIC_FUNC;
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VReluKernelImpl(int d) : VReluKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 /*init*/ +
d / AVX_FLOAT_BLOCK * 4 /* instructions*/ *
8 /*everage byte for each instruction*/;
size_t sz = 96 /* init size */ +
d / AVX_FLOAT_BLOCK * 4 /* instructions */ *
8 /* average bytes for each instruction */;
jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
......@@ -388,8 +377,6 @@ bool VReluKernelImpl<float>::useJIT(int d) {
}
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
......
......@@ -16,6 +16,11 @@ limitations under the License. */
#include <cmath> // for exp
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
......@@ -30,41 +35,84 @@ namespace math {
namespace jitkernel {
namespace jit = platform::jit;
// TODO(TJ): move refer codes to one file
template <typename T>
void VExpRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VExpMKL(const T* x, T* y, int n);
template <>
void VExpMKL<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y);
}
template <>
void VExpMKL<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
#endif
/* VExp JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
template <typename T>
class VExpKernelImpl : public VExpKernel<T> {
public:
explicit VExpKernelImpl(int d) : VExpKernel<T>() { this->num_ = d; }
void ComputeDeprecated(const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = std::exp(x[i]);
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VExpKernelImpl(int d) : VExpKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
jitcode_.reset(new gen::VExpJitCode(d, sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VExpMKL<T>;
return;
}
#endif
this->Compute = VExpRefer<T>;
}
void ComputeDeprecated(const T* x, T* y) const override {
VExpRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VExpJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VExpKernelImpl<float>::useJIT(int d) {
return gen::VExpJitCode::init(d);
}
#endif
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VExpKernelImpl<float, isa, block>::ComputeDeprecated(const float* x, \
float* y) const { \
platform::dynload::vsExp(this->num_, x, y); \
}
template <>
bool VExpKernelImpl<float>::useMKL(int d) {
return d > 512;
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VExpKernelImpl<double, isa, block>::ComputeDeprecated( \
const double* x, double* y) const { \
platform::dynload::vdExp(this->num_, x, y); \
}
FOR_EACH_ISA(MKL_FLOAT, kLT8);
FOR_EACH_ISA(MKL_FLOAT, kGT8LT16);
FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
template <>
bool VExpKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
namespace detail {
REGISTER_JITKERNEL(vexp, VExpKernel);
#ifdef __AVX__
namespace detail {
#define ALIGN32 __attribute__((aligned(32)))
......@@ -195,7 +243,6 @@ __m256 ExpAVX(__m256 x) {
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x) {
......@@ -211,47 +258,6 @@ __m256 ExpAVX2(__m256 x) {
} // namespace detail
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, expisa(tmp)); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = expisa(tmp0); \
tmp1 = expisa(tmp1); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
REGISTER_JITKERNEL_DEPRECATED(vexp, VExpKernel);
/* VSigmoid JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
......
......@@ -15,12 +15,20 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
......
......@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data);
// ker->ComputeDeprecated(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册