未验证 提交 7f17e561 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #14423 from tensor-tang/fea/jit/act

jitcode act relu, exp, sigmoid, tanh
...@@ -33,11 +33,11 @@ namespace math { ...@@ -33,11 +33,11 @@ namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define AVX_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX_DOUBLE_BLOCK 4 #define AVX_DOUBLE_BLOCK 4
#define AVX2_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX2_DOUBLE_BLOCK 4 #define AVX2_DOUBLE_BLOCK 4
#define AVX512_FLOAT_BLOCK 16 #define ZMM_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8 #define AVX512_DOUBLE_BLOCK 8
template <typename T> template <typename T>
...@@ -88,7 +88,7 @@ template <> ...@@ -88,7 +88,7 @@ template <>
inline void vec_scal<float, platform::jit::avx>(const int n, const float a, inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_scal<float, platform::jit::isa_any>(n, a, x, y); vec_scal<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -142,7 +142,7 @@ template <> ...@@ -142,7 +142,7 @@ template <>
inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a, inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y); vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x, ...@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
const float* y, const float* z, const float* y, const float* z,
float* out) { float* out) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_cross<float, platform::jit::isa_any>(n, x, y, z, out); vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
return; return;
...@@ -257,7 +257,7 @@ template <> ...@@ -257,7 +257,7 @@ template <>
inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a, inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_add_bias<float, platform::jit::isa_any>(n, a, x, y); vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -326,7 +326,7 @@ template <> ...@@ -326,7 +326,7 @@ template <>
inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x, inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
float* y) { float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_sigmoid<float, platform::jit::isa_any>(n, x, y); vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
return; return;
...@@ -415,7 +415,7 @@ template <> ...@@ -415,7 +415,7 @@ template <>
inline void vec_relu<float, platform::jit::avx>(const int n, const float* x, inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
float* y) { float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block * 4) { if (n < block * 4) {
vec_relu<float, platform::jit::isa_any>(n, x, y); vec_relu<float, platform::jit::isa_any>(n, x, y);
return; return;
......
...@@ -41,7 +41,7 @@ void VXXJitCode::generate() { ...@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
} else if (scalar_index_ == 2) { } else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]); vbroadcastss(ymm_src2, ptr[param2]);
} }
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]); vmovups(ymm_src1, ptr[param1 + offset]);
} }
...@@ -57,9 +57,9 @@ void VXXJitCode::generate() { ...@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
vmaxps(ymm_dst, ymm_zero, ymm_dst); vmaxps(ymm_dst, ymm_zero, ymm_dst);
} }
vmovups(ptr[param3 + offset], ymm_dst); vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
int rest = num_ % AVX_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
...@@ -118,18 +118,237 @@ void VXXJitCode::generate() { ...@@ -118,18 +118,237 @@ void VXXJitCode::generate() {
ret(); ret();
} }
bool ReluJitCode::init(int d) { return MayIUse(avx); } #define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
void ReluJitCode::generate() { #define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
int offset = 0;
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
static const float exp_float_consts[] ALIGN32 = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF),
REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2),
REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1),
REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3),
REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5),
REPEAT_8TIMES(EXP_MAX_INPUT),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static int g_tmp_mem[16] ALIGN32 = {0};
bool VActJitCode::init(int d, operand_type type) {
bool ok = MayIUse(avx);
if (type == operand_type::relu) {
return ok;
} else if (type == operand_type::exp) {
// exp is slower than mkl when d >= 256
return ok && d % 8 == 0 && d < 256;
} else {
// TODO(TJ): support more
return ok && d % 8 == 0;
}
}
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
vmaxps(ymm_dst, ymm_zero, ymm_src);
}
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
// check all idx can not equal
ymm_t ymm_fx = ymm_t(fx_idx);
ymm_t ymm_fy = ymm_t(fy_idx);
ymm_t ymm_mask = ymm_t(mask_idx);
ymm_t ymm_tmp = ymm_t(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(ymm_src, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
vmaxps(ymm_src, ymm_src, ymm_tmp);
// express exp(x) as exp(g + n*log(2))
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
vmulps(ymm_fx, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
vaddps(ymm_fx, ymm_fx, ymm_tmp);
vroundps(ymm_fy, ymm_fx, 0x01);
// if greater, substract 1
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vandps(ymm_mask, ymm_mask, ymm_tmp);
vsubps(ymm_fx, ymm_fy, ymm_mask);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
vmulps(ymm_fy, ymm_fx, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
ymm_t ymm_z = ymm_t(ymm_mask.getIdx());
vmulps(ymm_z, ymm_fx, ymm_tmp);
vsubps(ymm_src, ymm_src, ymm_fy);
vsubps(ymm_src, ymm_src, ymm_z);
vmulps(ymm_z, ymm_src, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(ymm_dst, ymm_src, ymm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_src);
}
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_z);
vaddps(ymm_dst, ymm_dst, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
// build 2^n
ymm_t ymm_int = ymm_fx;
vcvttps2dq(ymm_int, ymm_fx);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
if (MayIUse(avx2)) {
vpaddd(ymm_int, ymm_int, ymm_tmp);
vpslld(ymm_int, ymm_int, 23);
} else if (MayIUse(avx)) {
xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx());
reg64_t reg_ptr_tmp = reg_ptr_global;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
vmovdqa(ptr[reg_ptr_tmp], ymm_int);
vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_tmp], xtmp1);
// next 128bits
vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
vmovdqa(xtmp2,
ptr[reg_ptr_tmp +
(YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
// load out
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
}
vmulps(ymm_dst, ymm_dst, ymm_int);
pop(reg_ptr_global);
}
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
// y = 1 / (1 + e^-x)
ymm_t ymm_tmp = ymm_t(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
vminps(ymm_src, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
vmaxps(ymm_src, ymm_src, ymm_tmp);
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
vsubps(ymm_src, ymm_tmp, ymm_src);
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
pop(reg_ptr_global);
}
void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
// y = 2 / (1 + e^(-2x)) - 1
ymm_t ymm_tmp = ymm_t(tmp_idx);
ymm_t ymm_zero = ymm_t(mask_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(ymm_zero, ymm_zero, ymm_zero); vxorps(ymm_zero, ymm_zero, ymm_zero);
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { vsubps(ymm_tmp, ymm_zero, ymm_tmp);
vmulps(ymm_src, ymm_src, ymm_tmp);
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vsubps(ymm_dst, ymm_dst, ymm_tmp);
pop(reg_ptr_global);
}
void VActJitCode::generate() {
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_zero = ymm_t(2);
if (type_ == operand_type::relu) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]); vmovups(ymm_src, ptr[param1 + offset]);
vmaxps(ymm_dst, ymm_zero, ymm_src); switch (type_) {
case operand_type::relu:
relu_ymm(ymm_dst, ymm_src, ymm_zero);
break;
case operand_type::exp:
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::identity:
break;
default:
break;
}
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
int rest = num_ % AVX_FLOAT_BLOCK; if (type_ != operand_type::relu) {
// TODO(TJ): remove me
ret();
return;
}
int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]); vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src); vmaxps(xmm_dst, xmm_zero, xmm_src);
...@@ -151,6 +370,7 @@ void ReluJitCode::generate() { ...@@ -151,6 +370,7 @@ void ReluJitCode::generate() {
} }
ret(); ret();
} }
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm; ...@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm; using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label; using Label = Xbyak::Label;
typedef enum { mul = 0, add } operand_type; typedef enum {
mul = 0,
add,
sub,
relu,
exp,
sigmoid,
tanh,
identity
} operand_type;
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode { class VXXJitCode : public JitCode {
...@@ -85,26 +94,65 @@ class VXXJitCode : public JitCode { ...@@ -85,26 +94,65 @@ class VXXJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(3); ymm_t ymm_zero = ymm_t(3);
}; };
class ReluJitCode : public JitCode { class VActJitCode : public JitCode {
public: public:
DECLARE_JIT_CODE(ReluJitCode); const char* name() const override {
explicit ReluJitCode(int d, size_t code_size = 256 * 1024, std::string base = "VActJitCode";
switch (type_) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
return base.c_str();
}
explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {} : JitCode(code_size, code_ptr), num_(d), type_(type) {}
static bool init(int d); static bool init(int d, operand_type type);
void generate() override; void generate() override;
private: protected:
// compute relu with ymm
void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
const Xbyak::Ymm& zero);
// compute exp with ymm
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
// compute sigmoid with ymm
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
// compute tanh with ymm
void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
protected:
int num_; int num_;
operand_type type_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
reg64_t param2{abi_param2}; reg64_t param2{abi_param2};
xmm_t xmm_zero = xmm_t(0); xmm_t xmm_src = xmm_t(0);
xmm_t xmm_src = xmm_t(1); ymm_t ymm_src = ymm_t(0);
xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_zero = ymm_t(0); xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_src = ymm_t(1);
ymm_t ymm_dst = ymm_t(1); ymm_t ymm_dst = ymm_t(1);
}; };
......
...@@ -29,9 +29,9 @@ namespace jitkernel { ...@@ -29,9 +29,9 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
#define AVX_FLOAT_BLOCK 8 #define XMM_FLOAT_BLOCK 4
#define AVX2_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16 #define ZMM_FLOAT_BLOCK 16
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
...@@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel { ...@@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel {
template <typename T> template <typename T>
class VActKernel : public Kernel { class VActKernel : public Kernel {
public: public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0; void (*Compute)(const T *, T *, int);
}; };
template <typename T> template <typename T>
class VReluKernel : public VActKernel<T> { class VReluKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T> template <typename T>
class VIdentityKernel : public VActKernel<T> { class VIdentityKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class VExpKernel : public VActKernel<T> { class VExpKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class VSigmoidKernel : public VActKernel<T> { class VSigmoidKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class VTanhKernel : public VActKernel<T> { class VTanhKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class LSTMKernel : public Kernel {
......
...@@ -25,10 +25,6 @@ limitations under the License. */ ...@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -128,23 +124,16 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) { ...@@ -128,23 +124,16 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
#endif #endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */ /* VMUL JitKernel */
template <typename T> template <typename T>
class VMulKernelImpl : public VMulKernel<T> { class VMulKernelImpl : public VMulKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VMulKernelImpl(int d) : VMulKernel<T>() { explicit VMulKernelImpl(int d) : VMulKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
// roughly estimate the size of code // roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -191,11 +180,11 @@ bool VMulKernelImpl<double>::useMKL(int d) { ...@@ -191,11 +180,11 @@ bool VMulKernelImpl<double>::useMKL(int d) {
template <typename T> template <typename T>
class VAddKernelImpl : public VAddKernel<T> { class VAddKernelImpl : public VAddKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddKernelImpl(int d) : VAddKernel<T>() { explicit VAddKernelImpl(int d) : VAddKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -241,11 +230,11 @@ bool VAddKernelImpl<double>::useMKL(int d) { ...@@ -241,11 +230,11 @@ bool VAddKernelImpl<double>::useMKL(int d) {
template <typename T> template <typename T>
class VAddReluKernelImpl : public VAddReluKernel<T> { class VAddReluKernelImpl : public VAddReluKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -273,11 +262,11 @@ bool VAddReluKernelImpl<float>::useJIT(int d) { ...@@ -273,11 +262,11 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
template <typename T> template <typename T>
class VScalKernelImpl : public VScalKernel<T> { class VScalKernelImpl : public VScalKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VScalKernelImpl(int d) : VScalKernel<T>() { explicit VScalKernelImpl(int d) : VScalKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -322,11 +311,11 @@ bool VScalKernelImpl<double>::useMKL(int d) { ...@@ -322,11 +311,11 @@ bool VScalKernelImpl<double>::useMKL(int d) {
template <typename T> template <typename T>
class VAddBiasKernelImpl : public VAddBiasKernel<T> { class VAddBiasKernelImpl : public VAddBiasKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -355,15 +344,15 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) { ...@@ -355,15 +344,15 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
template <typename T> template <typename T>
class VReluKernelImpl : public VReluKernel<T> { class VReluKernelImpl : public VReluKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VReluKernelImpl(int d) : VReluKernel<T>() { explicit VReluKernelImpl(int d) : VReluKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 /*init*/ + size_t sz = 96 /* init size */ +
d / AVX_FLOAT_BLOCK * 4 /* instructions*/ * d / YMM_FLOAT_BLOCK * 4 /* instructions */ *
8 /*everage byte for each instruction*/; 8 /* average bytes for each instruction */;
jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return; return;
} }
...@@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel<T> {
this->Compute = VReluRefer<T>; this->Compute = VReluRefer<T>;
} }
void ComputeDeprecated(const T* x, T* y) const override {
VReluRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VReluKernelImpl<float>::useJIT(int d) { bool VReluKernelImpl<float>::useJIT(int d) {
return gen::ReluJitCode::init(d); return gen::VActJitCode::init(d, gen::operand_type::relu);
} }
#endif #endif
#undef DECLARE_STATIC_FUNC template <typename T>
inline void VIdentityRefer(const T* x, T* y, int n) {}
/* An empty JitKernel */
template <typename T>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
this->Compute = VIdentityRefer<T>;
}
};
REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel); REGISTER_JITKERNEL(vadd, VAddKernel);
...@@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); ...@@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel); REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel); REGISTER_JITKERNEL(vrelu, VReluKernel);
REGISTER_JITKERNEL(videntity, VIdentityKernel);
/* An empty JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
void ComputeDeprecated(const T* x, T* y) const override {}
};
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \ int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX_FLOAT_BLOCK; \ this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX_FLOAT_BLOCK; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \ void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX_FLOAT_BLOCK) \ INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \ max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \ trans_offset += this->num_; \
} \ } \
UPDATE_ALPHA(AVX_FLOAT_BLOCK) \ UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \ } \
seq_offset += this->num_; \ seq_offset += this->num_; \
} \ } \
...@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \ CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \ this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \ void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX2_FLOAT_BLOCK) \ INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \ max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \ trans_offset += this->num_; \
} \ } \
UPDATE_ALPHA(AVX2_FLOAT_BLOCK) \ UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \ } \
seq_offset += this->num_; \ seq_offset += this->num_; \
} \ } \
...@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \ int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX512_FLOAT_BLOCK; \ this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX512_FLOAT_BLOCK; \ this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \ void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX512_FLOAT_BLOCK) \ INIT_ALPHA(ZMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->num_ + j_offset), \ this->num_ + j_offset), \
max_j); \ max_j); \
/* Calculate the offset of next step*/ \ /* Calculate the offset of next step*/ \
j_offset += AVX512_FLOAT_BLOCK; \ j_offset += ZMM_FLOAT_BLOCK; \
if (j == this->end_ - 1) { \ if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \ if (this->rest_ > 0) { \
j_offset += last_offset; \ j_offset += last_offset; \
......
...@@ -15,12 +15,20 @@ limitations under the License. */ ...@@ -15,12 +15,20 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \ #define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \ template <> \
std::string ker_class##Impl<float>::name(int d) { \ std::string ker_class##Impl<float>::name(int d) { \
...@@ -87,13 +95,13 @@ namespace jitkernel { ...@@ -87,13 +95,13 @@ namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
// TODO(TJ): below defines are deprecated, would be remove recently // TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ #define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < AVX_FLOAT_BLOCK) { \ if (d < YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kLT8); \ macro_(ker, dtype, isa, kLT8); \
} else if (d == AVX_FLOAT_BLOCK) { \ } else if (d == YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ8); \ macro_(ker, dtype, isa, kEQ8); \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ } else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \ macro_(ker, dtype, isa, kGT8LT16); \
} else if (d == AVX512_FLOAT_BLOCK) { \ } else if (d == ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \ macro_(ker, dtype, isa, kEQ16); \
} else { \ } else { \
macro_(ker, dtype, isa, kGT16); \ macro_(ker, dtype, isa, kGT16); \
......
...@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
T* checked) const override { T* checked) const override {
// gates: W_ch, W_ih, W_fh, W_oh // gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d3_->Compute(gates + d_, gates + d_, d3_);
/* C_t = C_t-1 * fgated + cand_gated * igated */ /* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/ /* C_t = igated * cgated*/
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_); vmul_d_->Compute(gates, gates + d_, ct, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
...@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> { ...@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/ /* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/ /* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/ /* C_t = igated * cgated*/
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_); vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */ /* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
...@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
} }
void ComputeH1(T* gates, T* ht) const override { void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->ComputeDeprecated(gates, gates); act_gate_d_->Compute(gates, gates, d_);
act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_); act_state_d_->Compute(gates + d2_, gates + d2_, d_);
vmul_d_->Compute(gates, gates + d2_, ht, d_); vmul_d_->Compute(gates, gates + d2_, ht, d_);
} }
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
act_gate_d2_->ComputeDeprecated(gates, gates); act_gate_d2_->Compute(gates, gates, d2_);
vmul_d_->Compute(ht_1, gates + d_, ht, d_); vmul_d_->Compute(ht_1, gates + d_, ht, d_);
} }
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_; T* y = gates + d2_;
act_state_d_->ComputeDeprecated(y, y); act_state_d_->Compute(y, y, d_);
// out = zt*ht~ + (1-zt)*ht_1 // out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) { for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i]; ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
......
...@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) { ...@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data); // ker->Compute(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -222,7 +223,7 @@ void vsigmoid_better( ...@@ -222,7 +223,7 @@ void vsigmoid_better(
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 0.f - y[i]; y[i] = 0.f - y[i];
} }
vexp->ComputeDeprecated(y, y); vexp->Compute(y, y, n);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = 1.f / (1.f + y[i]); y[i] = 1.f / (1.f + y[i]);
} }
...@@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) { ...@@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data); ker->Compute(x_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -287,7 +288,7 @@ void vtanh_better( ...@@ -287,7 +288,7 @@ void vtanh_better(
const int n, const float* x, float* y) { const int n, const float* x, float* y) {
const float a = 2.f, b = -1.f; const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n); vscal->Compute(&a, x, y, n);
vsigmoid->ComputeDeprecated(y, y); vsigmoid->Compute(y, y, n);
vscal->Compute(&a, y, y, n); vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n); vaddbias->Compute(&b, y, y, n);
} }
...@@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) { ...@@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data); ker->Compute(x_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -344,8 +345,8 @@ void lstm_ctht_ref( ...@@ -344,8 +345,8 @@ void lstm_ctht_ref(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1, const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
const int d, float* gates, const float* ct_1, float* ct, float* ht) { const int d, float* gates, const float* ct_1, float* ct, float* ht) {
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->ComputeDeprecated(gates, gates); vtanh_d->Compute(gates, gates, d);
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3; const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
const float min = SIGMOID_THRESHOLD_MIN; const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX; const float max = SIGMOID_THRESHOLD_MAX;
...@@ -355,7 +356,7 @@ void lstm_ctht_ref( ...@@ -355,7 +356,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated // H_t = act_cell(C_t) * ogated
float tmp = ct[k] * 2; float tmp = ct[k] * 2;
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vexp_1->ComputeDeprecated(&tmp, &tmp); vexp_1->Compute(&tmp, &tmp, 1);
tmp = 2.f / (1.f + tmp) - 1.f; tmp = 2.f / (1.f + tmp) - 1.f;
ht[k] = tmp * o[k]; ht[k] = tmp * o[k];
} }
...@@ -373,13 +374,13 @@ void lstm_ctht_better( ...@@ -373,13 +374,13 @@ void lstm_ctht_better(
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d, const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
const int d, float* gates, const float* ct_1, float* ct, float* ht) { const int d, float* gates, const float* ct_1, float* ct, float* ht) {
int d2 = d * 2; int d2 = d * 2;
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->ComputeDeprecated(gates, gates); vtanh_d->Compute(gates, gates, d);
vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
vadd_d->Compute(gates + d, gates + d2, ct, d); vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
vtanh_d->ComputeDeprecated(ct, gates + d2); vtanh_d->Compute(ct, gates + d2, d);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d); vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
} }
...@@ -736,7 +737,7 @@ void vaddrelu_better( ...@@ -736,7 +737,7 @@ void vaddrelu_better(
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu, const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z, int d) { const float* x, const float* y, float* z, int d) {
vadd->Compute(x, y, z, d); vadd->Compute(x, y, z, d);
vrelu->ComputeDeprecated(z, z); vrelu->Compute(z, z, d);
} }
TEST(JitKernel, vaddrelu) { TEST(JitKernel, vaddrelu) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册