提交 d3eae8f6 编写于 作者: T tensor-tang

refine relu and fix addrelu test

上级 4e67fe6a
...@@ -177,14 +177,6 @@ bool VActJitCode::init(int d, operand_type type) { ...@@ -177,14 +177,6 @@ bool VActJitCode::init(int d, operand_type type) {
} }
} }
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
vmaxps(ymm_dst, ymm_zero, ymm_src);
}
void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) {
vmaxps(xmm_dst, xmm_zero, xmm_src);
}
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) { int fy_idx, int mask_idx, int tmp_idx) {
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
...@@ -378,7 +370,7 @@ void VActJitCode::generate() { ...@@ -378,7 +370,7 @@ void VActJitCode::generate() {
vmovups(ymm_src, ptr[param1 + offset]); vmovups(ymm_src, ptr[param1 + offset]);
switch (type_) { switch (type_) {
case operand_type::relu: case operand_type::relu:
relu_ymm(ymm_dst, ymm_src, ymm_zero); relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
break; break;
case operand_type::exp: case operand_type::exp:
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
...@@ -414,7 +406,7 @@ void VActJitCode::generate() { ...@@ -414,7 +406,7 @@ void VActJitCode::generate() {
} }
switch (type_) { switch (type_) {
case operand_type::relu: case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero); relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
break; break;
case operand_type::exp: case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5); exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
......
...@@ -128,10 +128,10 @@ class VActJitCode : public JitCode { ...@@ -128,10 +128,10 @@ class VActJitCode : public JitCode {
protected: protected:
// compute relu with ymm, xmm // compute relu with ymm, xmm
void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, template <typename JMM>
const Xbyak::Ymm& zero); void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT
void relu_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src, vmaxps(dst, src, zero);
const Xbyak::Xmm& zero); }
// compute exp with ymm, xmm // compute exp with ymm, xmm
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
......
...@@ -762,7 +762,7 @@ TEST(JitKernel, vaddrelu) { ...@@ -762,7 +762,7 @@ TEST(JitKernel, vaddrelu) {
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vadd_ref(d, x_data, y_data, zref_data); vaddrelu_ref(d, x_data, y_data, zref_data);
} }
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS(); auto tmkls = GetCurrentUS();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册