提交 ba3eaed7 编写于 作者: T tensor-tang

exp support all size

上级 d239801b
...@@ -81,10 +81,10 @@ void VXXJitCode::generate() { ...@@ -81,10 +81,10 @@ void VXXJitCode::generate() {
} }
if (rest >= 2) { if (rest >= 2) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]); vmovq(xmm_src1, ptr[param1 + offset]);
} }
if (scalar_index_ != 2) { if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]); vmovq(xmm_src2, ptr[param2 + offset]);
} }
if (type_ == operand_type::mul) { if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2); vmulps(xmm_dst, xmm_src1, xmm_src2);
...@@ -100,10 +100,10 @@ void VXXJitCode::generate() { ...@@ -100,10 +100,10 @@ void VXXJitCode::generate() {
} }
if (rest > 0) { if (rest > 0) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src1, ptr[param1 + offset]);
} }
if (scalar_index_ != 2) { if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]); vmovss(xmm_src2, ptr[param2 + offset]);
} }
if (type_ == operand_type::mul) { if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2); vmulss(xmm_dst, xmm_src1, xmm_src2);
...@@ -179,7 +179,7 @@ bool VActJitCode::init(int d, operand_type type) { ...@@ -179,7 +179,7 @@ bool VActJitCode::init(int d, operand_type type) {
return ok; return ok;
} else if (type == operand_type::exp) { } else if (type == operand_type::exp) {
// exp is slower than mkl when d >= 256 // exp is slower than mkl when d >= 256
return ok && d % 8 == 0 && d < 256; return ok; //&& d % 4 == 0 && d < 256;
} else { } else {
// TODO(TJ): support more // TODO(TJ): support more
return ok && d % 8 == 0; return ok && d % 8 == 0;
...@@ -190,6 +190,10 @@ void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) { ...@@ -190,6 +190,10 @@ void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
vmaxps(ymm_dst, ymm_zero, ymm_src); vmaxps(ymm_dst, ymm_zero, ymm_src);
} }
void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) {
vmaxps(xmm_dst, xmm_zero, xmm_src);
}
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) { int fy_idx, int mask_idx, int tmp_idx) {
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
...@@ -271,6 +275,65 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, ...@@ -271,6 +275,65 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
pop(reg_ptr_global); pop(reg_ptr_global);
} }
void VActJitCode::exp_xmm(xmm_t& ymm_dst, xmm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
// check all idx can not equal
xmm_t ymm_fx = xmm_t(fx_idx);
xmm_t ymm_fy = xmm_t(fy_idx);
xmm_t ymm_mask = xmm_t(mask_idx);
xmm_t ymm_tmp = xmm_t(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(ymm_src, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
vmaxps(ymm_src, ymm_src, ymm_tmp);
// express exp(x) as exp(g + n*log(2))
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
vmulps(ymm_fx, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
vaddps(ymm_fx, ymm_fx, ymm_tmp);
vroundps(ymm_fy, ymm_fx, 0x01);
// if greater, substract 1
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vandps(ymm_mask, ymm_mask, ymm_tmp);
vsubps(ymm_fx, ymm_fy, ymm_mask);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
vmulps(ymm_fy, ymm_fx, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
xmm_t ymm_z = xmm_t(ymm_mask.getIdx());
vmulps(ymm_z, ymm_fx, ymm_tmp);
vsubps(ymm_src, ymm_src, ymm_fy);
vsubps(ymm_src, ymm_src, ymm_z);
vmulps(ymm_z, ymm_src, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(ymm_dst, ymm_src, ymm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_src);
}
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_z);
vaddps(ymm_dst, ymm_dst, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
// build 2^n
xmm_t ymm_int = ymm_fx;
vcvttps2dq(ymm_int, ymm_fx);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
vpaddd(ymm_int, ymm_int, ymm_tmp);
vpslld(ymm_int, ymm_int, 23);
vmulps(ymm_dst, ymm_dst, ymm_int);
pop(reg_ptr_global);
}
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) { int fy_idx, int mask_idx, int tmp_idx) {
// y = 1 / (1 + e^-x) // y = 1 / (1 + e^-x)
...@@ -343,7 +406,7 @@ void VActJitCode::generate() { ...@@ -343,7 +406,7 @@ void VActJitCode::generate() {
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
if (type_ != operand_type::relu) { if (type_ != operand_type::relu && type_ != operand_type::exp) {
// TODO(TJ): remove me // TODO(TJ): remove me
ret(); ret();
return; return;
...@@ -351,21 +414,50 @@ void VActJitCode::generate() { ...@@ -351,21 +414,50 @@ void VActJitCode::generate() {
int rest = num_ % YMM_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]); vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src); switch (type_) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
}
vmovups(ptr[param2 + offset], xmm_dst); vmovups(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 4; offset += sizeof(float) * 4;
rest -= 4; rest -= 4;
} }
if (rest >= 2) { if (rest >= 2) {
vmovups(xmm_src, ptr[param1 + offset]); vmovq(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src); switch (type_) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
}
vmovq(ptr[param2 + offset], xmm_dst); vmovq(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 2; offset += sizeof(float) * 2;
rest -= 2; rest -= 2;
} }
if (rest > 0) { if (rest > 0) {
vmovups(xmm_src, ptr[param1 + offset]); // vmovups();
vmaxps(xmm_dst, xmm_zero, xmm_src); vmovss(xmm_src, ptr[param1 + offset]);
switch (type_) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
}
vmovss(ptr[param2 + offset], xmm_dst); vmovss(ptr[param2 + offset], xmm_dst);
} }
ret(); ret();
......
...@@ -127,13 +127,17 @@ class VActJitCode : public JitCode { ...@@ -127,13 +127,17 @@ class VActJitCode : public JitCode {
void generate() override; void generate() override;
protected: protected:
// compute relu with ymm // compute relu with ymm, xmm
void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
const Xbyak::Ymm& zero); const Xbyak::Ymm& zero);
void relu_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src,
const Xbyak::Xmm& zero);
// compute exp with ymm // compute exp with ymm, xmm
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
void exp_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
// compute sigmoid with ymm // compute sigmoid with ymm
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
......
...@@ -33,6 +33,9 @@ limitations under the License. */ ...@@ -33,6 +33,9 @@ limitations under the License. */
constexpr int repeat = 20000; constexpr int repeat = 20000;
// TODO(TJ): benchmark and test should be seperated,
// benchmark should verify more sizes
inline double GetCurrentUS() { inline double GetCurrentUS() {
struct timeval time; struct timeval time;
gettimeofday(&time, NULL); gettimeofday(&time, NULL);
...@@ -156,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) { ...@@ -156,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
TEST(JitKernel, vexp) { TEST(JitKernel, vexp) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 128, 256}) { for (int d : {7, 8, 12, 15, 16, 20, 30, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f); RandomVec<float>(d, x.data(), -2.f, 2.f);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册