提交 4dbdfa60 编写于 作者: T tensor-tang

sigmoid and tanh support all size

test=develop
上级 ccb89637
...@@ -132,56 +132,8 @@ const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; ...@@ -132,56 +132,8 @@ const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
int g_tmp_mem[16] ALIGN32 = {0}; int g_tmp_mem[16] ALIGN32 = {0};
bool VActJitCode::init(int d, operand_type type) { bool VActJitCode::init(int d, operand_type type) {
bool ok = MayIUse(avx);
if (type == operand_type::relu || type == operand_type::exp) {
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256 // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
return ok; return MayIUse(avx);
} else {
// TODO(TJ): support more
return ok && d % 8 == 0;
}
}
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
// y = 1 / (1 + e^-x)
ymm_t ymm_tmp = ymm_t(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
vminps(ymm_src, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
vmaxps(ymm_src, ymm_src, ymm_tmp);
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
vsubps(ymm_src, ymm_tmp, ymm_src);
exp_jmm<ymm_t>(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
pop(reg_ptr_global);
}
void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
// y = 2 / (1 + e^(-2x)) - 1
ymm_t ymm_tmp = ymm_t(tmp_idx);
ymm_t ymm_zero = ymm_t(mask_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(ymm_zero, ymm_zero, ymm_zero);
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
vmulps(ymm_src, ymm_src, ymm_tmp);
exp_jmm<ymm_t>(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vsubps(ymm_dst, ymm_dst, ymm_tmp);
pop(reg_ptr_global);
} }
void VActJitCode::generate() { void VActJitCode::generate() {
...@@ -201,10 +153,10 @@ void VActJitCode::generate() { ...@@ -201,10 +153,10 @@ void VActJitCode::generate() {
exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5); exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break; break;
case operand_type::sigmoid: case operand_type::sigmoid:
sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); sigmoid_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break; break;
case operand_type::tanh: case operand_type::tanh:
tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); tanh_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break; break;
case operand_type::identity: case operand_type::identity:
break; break;
...@@ -214,11 +166,6 @@ void VActJitCode::generate() { ...@@ -214,11 +166,6 @@ void VActJitCode::generate() {
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
if (type_ != operand_type::relu && type_ != operand_type::exp) {
// TODO(TJ): remove me
ret();
return;
}
int rest = num_ % YMM_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
int block = XMM_FLOAT_BLOCK; int block = XMM_FLOAT_BLOCK;
while (rest > 0) { while (rest > 0) {
...@@ -236,6 +183,12 @@ void VActJitCode::generate() { ...@@ -236,6 +183,12 @@ void VActJitCode::generate() {
case operand_type::exp: case operand_type::exp:
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5); exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break; break;
case operand_type::sigmoid:
sigmoid_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default: default:
break; break;
} }
......
...@@ -263,13 +263,51 @@ class VActJitCode : public JitCode { ...@@ -263,13 +263,51 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute sigmoid with ymm // compute sigmoid with ymm, xmm
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, template <typename JMM>
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2, // NOLINT
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) {
// y = 1 / (1 + e^-x)
JMM jmm_tmp = JMM(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
vminps(src, src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
vmaxps(src, src, jmm_tmp);
vxorps(jmm_tmp, jmm_tmp, jmm_tmp);
vsubps(src, jmm_tmp, src);
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(dst, dst, jmm_tmp);
vdivps(dst, jmm_tmp, dst);
pop(reg_ptr_global);
}
// compute tanh with ymm // compute tanh with ymm, xmm
void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, template <typename JMM>
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT
int mask_idx = 4, int tmp_idx = 5) {
// y = 2 / (1 + e^(-2x)) - 1
JMM jmm_tmp = JMM(tmp_idx);
JMM jmm_zero = JMM(mask_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(jmm_zero, jmm_zero, jmm_zero);
vsubps(jmm_tmp, jmm_zero, jmm_tmp);
vmulps(src, src, jmm_tmp);
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(dst, dst, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vdivps(dst, jmm_tmp, dst);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vsubps(dst, dst, jmm_tmp);
pop(reg_ptr_global);
}
protected: protected:
int num_; int num_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册