提交 d3eae8f6 编写于 作者: T tensor-tang

refine relu and fix addrelu test

上级 4e67fe6a
......@@ -177,14 +177,6 @@ bool VActJitCode::init(int d, operand_type type) {
}
}
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
vmaxps(ymm_dst, ymm_zero, ymm_src);
}
void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) {
vmaxps(xmm_dst, xmm_zero, xmm_src);
}
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
......@@ -378,7 +370,7 @@ void VActJitCode::generate() {
vmovups(ymm_src, ptr[param1 + offset]);
switch (type_) {
case operand_type::relu:
relu_ymm(ymm_dst, ymm_src, ymm_zero);
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
break;
case operand_type::exp:
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
......@@ -414,7 +406,7 @@ void VActJitCode::generate() {
}
switch (type_) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
......
......@@ -128,10 +128,10 @@ class VActJitCode : public JitCode {
protected:
// compute relu with ymm, xmm
void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
const Xbyak::Ymm& zero);
void relu_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src,
const Xbyak::Xmm& zero);
template <typename JMM>
void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT
vmaxps(dst, src, zero);
}
// compute exp with ymm, xmm
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
......
......@@ -762,7 +762,7 @@ TEST(JitKernel, vaddrelu) {
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vadd_ref(d, x_data, y_data, zref_data);
vaddrelu_ref(d, x_data, y_data, zref_data);
}
auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册