From d3eae8f61b26c4fa053a74ce35aeb241db2c3b3b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 16 Nov 2018 14:58:43 +0000 Subject: [PATCH] refine relu and fix addrelu test --- paddle/fluid/operators/math/jit_code.cc | 12 ++---------- paddle/fluid/operators/math/jit_code.h | 8 ++++---- paddle/fluid/operators/math/jit_kernel_test.cc | 2 +- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index a5eef019c8..2a10cd7821 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -177,14 +177,6 @@ bool VActJitCode::init(int d, operand_type type) { } } -void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) { - vmaxps(ymm_dst, ymm_zero, ymm_src); -} - -void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) { - vmaxps(xmm_dst, xmm_zero, xmm_src); -} - void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, int fy_idx, int mask_idx, int tmp_idx) { assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore @@ -378,7 +370,7 @@ void VActJitCode::generate() { vmovups(ymm_src, ptr[param1 + offset]); switch (type_) { case operand_type::relu: - relu_ymm(ymm_dst, ymm_src, ymm_zero); + relu_jmm(ymm_dst, ymm_src, ymm_zero); break; case operand_type::exp: exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); @@ -414,7 +406,7 @@ void VActJitCode::generate() { } switch (type_) { case operand_type::relu: - relu_xmm(xmm_dst, xmm_src, xmm_zero); + relu_jmm(xmm_dst, xmm_src, xmm_zero); break; case operand_type::exp: exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5); diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 1467978f26..6adeebca7c 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -128,10 +128,10 @@ class VActJitCode : public JitCode { protected: // compute relu with ymm, xmm - void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, - const Xbyak::Ymm& zero); - void relu_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src, - const Xbyak::Xmm& zero); + template + void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT + vmaxps(dst, src, zero); + } // compute exp with ymm, xmm void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 178298bf56..932fa4c000 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -762,7 +762,7 @@ TEST(JitKernel, vaddrelu) { float* zref_data = zref.data(); auto trefs = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - vadd_ref(d, x_data, y_data, zref_data); + vaddrelu_ref(d, x_data, y_data, zref_data); } auto trefe = GetCurrentUS(); auto tmkls = GetCurrentUS(); -- GitLab