diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h index 87220d4019fc9337fb8355172ca4f1372cfd4558..b072b4c20a171d148bd892c162436d03da404fb9 100644 --- a/paddle/fluid/operators/math/fc_compute.h +++ b/paddle/fluid/operators/math/fc_compute.h @@ -36,7 +36,7 @@ inline void FCCompute(const BlasT& blas, const int M, .template Get>(N); for (int i = 0; i < M; i++) { T* dst = Y + i * N; - vaddrelu->Compute(B, dst, dst); + vaddrelu->Compute(B, dst, dst, N); } } else { const auto& vadd = jitkernel::KernelPool::Instance() @@ -47,7 +47,7 @@ inline void FCCompute(const BlasT& blas, const int M, #endif for (int i = 0; i < M; i++) { T* dst = Y + i * N; - vadd->Compute(B, dst, dst); + vadd->Compute(B, dst, dst, N); } } } diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 9e2cc18c7a5e396be40b2336382f68a17f8a2bf9..a92e5d351e71a55bca2845ce275780950d096031 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -24,19 +24,29 @@ namespace gen { using namespace platform::jit; // NOLINT -bool VMulJitCode::init(int d) { +bool VVVJitCode::init(int d) { // It's not necessary to use avx512 since it would slow down the frequency // and this kernel is not compute bound. return MayIUse(avx); } -void VMulJitCode::generate() { +void VVVJitCode::generate() { // do not need push stack, and do not need save avx512reg if do not use avx512 int offset = 0; + if (with_relu_) { + vxorps(ymm_zero, ymm_zero, ymm_zero); + } for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { vmovups(ymm_src1, ptr[param1 + offset]); vmovups(ymm_src2, ptr[param2 + offset]); - vmulps(ymm_dst, ymm_src1, ymm_src2); + if (type_ == operand_type::mul) { + vmulps(ymm_dst, ymm_src1, ymm_src2); + } else if (type_ == operand_type::add) { + vaddps(ymm_dst, ymm_src1, ymm_src2); + } + if (with_relu_) { + vmaxps(ymm_dst, ymm_zero, ymm_dst); + } vmovups(ptr[param3 + offset], ymm_dst); offset += sizeof(float) * AVX_FLOAT_BLOCK; } @@ -44,7 +54,14 @@ void VMulJitCode::generate() { if (rest >= 4) { vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src2, ptr[param2 + offset]); - vmulps(xmm_dst, xmm_src1, xmm_src2); + if (type_ == operand_type::mul) { + vmulps(xmm_dst, xmm_src1, xmm_src2); + } else if (type_ == operand_type::add) { + vaddps(xmm_dst, xmm_src1, xmm_src2); + } + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovups(ptr[param3 + offset], xmm_dst); offset += sizeof(float) * 4; rest -= 4; @@ -52,7 +69,14 @@ void VMulJitCode::generate() { if (rest >= 2) { vmovq(xmm_src1, ptr[param1 + offset]); vmovq(xmm_src2, ptr[param2 + offset]); - vmulps(xmm_dst, xmm_src1, xmm_src2); + if (type_ == operand_type::mul) { + vmulps(xmm_dst, xmm_src1, xmm_src2); + } else if (type_ == operand_type::add) { + vaddps(xmm_dst, xmm_src1, xmm_src2); + } + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovq(ptr[param3 + offset], xmm_dst); offset += sizeof(float) * 2; rest -= 2; @@ -60,12 +84,18 @@ void VMulJitCode::generate() { if (rest > 0) { vmovss(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src2, ptr[param2 + offset]); - vmulss(xmm_dst, xmm_src1, xmm_src2); + if (type_ == operand_type::mul) { + vmulss(xmm_dst, xmm_src1, xmm_src2); + } else if (type_ == operand_type::add) { + vaddss(xmm_dst, xmm_src1, xmm_src2); + } + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovss(ptr[param3 + offset], xmm_dst); } ret(); } - } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 6007b290815de0ceaa2226962c5273ae7da72e7e..73692ebc67c71f6190f2d18bd50071a28a35d4c9 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/operators/math/jit_gen.h" - namespace paddle { namespace operators { namespace math { @@ -29,28 +29,47 @@ using ymm_t = const Xbyak::Ymm; using zmm_t = const Xbyak::Zmm; using Label = Xbyak::Label; -class VMulJitCode : public JitCode { +// function: vec = Operand(vec, vec) (maybe with relu) +typedef enum { mul = 0, add } operand_type; + +class VVVJitCode : public JitCode { public: - DECLARE_JIT_CODE(VMulJitCode); - explicit VMulJitCode(int d, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d) {} + const char* name() const override { + std::string base = "VVVJitCode"; + if (type_ == operand_type::mul) { + base += "_Mul"; + } else if (type_ == operand_type::add) { + base += "_Add"; + } + base += (with_relu_ ? "_relu" : ""); + return base.c_str(); + } + explicit VVVJitCode(int d, operand_type type, bool with_relu, + size_t code_size = 256 * 1024, void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), + num_(d), + type_(type), + with_relu_(with_relu) {} static bool init(int d); void generate() override; private: int num_; + operand_type type_; + bool with_relu_; reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; reg64_t param3{abi_param3}; xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(2); + xmm_t xmm_dst = xmm_t(1); + xmm_t xmm_zero = xmm_t(2); ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(2); + ymm_t ymm_dst = ymm_t(1); + ymm_t ymm_zero = ymm_t(2); }; } // namespace gen diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 7b6027aa267803ff8ff830deabda536b1b27fec8..04e0b81d3e7c696ac2f5ee78db90fb3c89ab345d 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -71,26 +71,26 @@ class VMulKernel : public Kernel { template class VAddKernel : public Kernel { public: - virtual void Compute(const T *x, const T *y, T *z) const = 0; + void (*Compute)(const T *, const T *, T *, int); }; template -class VScalKernel : public Kernel { +class VAddReluKernel : public Kernel { public: - virtual void Compute(const T a, const T *x, T *y) const = 0; - virtual void Compute(const T a, T *x) const = 0; + void (*Compute)(const T *, const T *, T *, int); }; template -class VAddBiasKernel : public Kernel { +class VScalKernel : public Kernel { public: virtual void Compute(const T a, const T *x, T *y) const = 0; + virtual void Compute(const T a, T *x) const = 0; }; template -class VAddReluKernel : public Kernel { +class VAddBiasKernel : public Kernel { public: - virtual void Compute(const T *x, const T *y, T *z) const = 0; + virtual void Compute(const T a, const T *x, T *y) const = 0; }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 8a988f8f482e4a4963f70c39bccd89387c1e0059..9acb349f663cca1d38fa7840c3308dfa17a43ab1 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -42,6 +42,21 @@ void VMulRefer(const T* x, const T* y, T* z, int n) { } } +template +void VAddRefer(const T* x, const T* y, T* z, int n) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y[i]; + } +} + +template +void VAddReluRefer(const T* x, const T* y, T* z, int n) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y[i]; + z[i] = z[i] > 0 ? z[i] : 0; + } +} + #ifdef PADDLE_WITH_MKLML template void VMulMKL(const T* x, const T* y, T* z, int n); @@ -50,28 +65,45 @@ template <> void VMulMKL(const float* x, const float* y, float* z, int n) { platform::dynload::vsMul(n, x, y, z); } + template <> void VMulMKL(const double* x, const double* y, double* z, int n) { platform::dynload::vdMul(n, x, y, z); } + +template +void VAddMKL(const T* x, const T* y, T* z, int n); + +template <> +void VAddMKL(const float* x, const float* y, float* z, int n) { + platform::dynload::vsAdd(n, x, y, z); +} + +template <> +void VAddMKL(const double* x, const double* y, double* z, int n) { + platform::dynload::vdAdd(n, x, y, z); +} #endif +#define DECLARE_STATIC_FUNC \ + static inline std::string name(int d) { \ + PADDLE_THROW("DType should be either float or double"); \ + } \ + static inline bool useJIT(int d) { return false; } \ + static inline bool useMKL(int d) { return false; } + /* VMUL JitKernel */ template class VMulKernelImpl : public VMulKernel { public: - static inline std::string name(int d) { - PADDLE_THROW("DType should be either float or double"); - } - static inline bool useJIT(int d) { return false; } - static inline bool useMKL(int d) { return false; } - + DECLARE_STATIC_FUNC; explicit VMulKernelImpl(int d) : VMulKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { // roughly estimate the size of code size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VMulJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; @@ -89,14 +121,14 @@ class VMulKernelImpl : public VMulKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VMulKernelImpl::useJIT(int d) { - return gen::VMulJitCode::init(d); + return gen::VVVJitCode::init(d); } #endif @@ -112,63 +144,89 @@ bool VMulKernelImpl::useMKL(int d) { } #endif -REGISTER_JITKERNEL(vmul, VMulKernel); - -/* VADD JitKernel */ -template +/* VAdd JitKernel */ +template class VAddKernelImpl : public VAddKernel { public: - explicit VAddKernelImpl(int d) : VAddKernel() { this->num_ = d; } - void Compute(const T* x, const T* y, T* z) const override { - for (int i = 0; i < this->num_; ++i) { - z[i] = x[i] + y[i]; + DECLARE_STATIC_FUNC; + explicit VAddKernelImpl(int d) : VAddKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false, + sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; } - } -}; - +#endif #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VAddKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - platform::dynload::vsAdd(this->num_, x, y, z); \ + if (useMKL(d)) { + this->Compute = VAddMKL; + return; + } +#endif + this->Compute = VAddRefer; } -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VAddKernelImpl::Compute( \ - const double* x, const double* y, double* z) const { \ - platform::dynload::vdAdd(this->num_, x, y, z); \ - } + private: + std::unique_ptr jitcode_{nullptr}; +}; -FOR_EACH_ISA(MKL_FLOAT, kGT16); -FOR_EACH_ISA_BLOCK(MKL_DOUBLE); +#ifdef PADDLE_WITH_XBYAK +template <> +bool VAddKernelImpl::useJIT(int d) { + return gen::VVVJitCode::init(d); +} #endif -#define INTRI8_FLOAT(isa) \ - template <> \ - void VAddKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 tmpx, tmpy; \ - tmpx = _mm256_loadu_ps(x); \ - tmpy = _mm256_loadu_ps(y); \ - tmpx = _mm256_add_ps(tmpx, tmpy); \ - _mm256_storeu_ps(z, tmpx); \ - } -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); +#ifdef PADDLE_WITH_MKLML +template <> +bool VAddKernelImpl::useMKL(int d) { + return d > 512; +} + +template <> +bool VAddKernelImpl::useMKL(int d) { + return true; +} #endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); + +/* VAddRelu JitKernel */ +template +class VAddReluKernelImpl : public VAddReluKernel { + public: + DECLARE_STATIC_FUNC; + explicit VAddReluKernelImpl(int d) : VAddReluKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true, + sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; + } #endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); + this->Compute = VAddReluRefer; + } + + private: + std::unique_ptr jitcode_{nullptr}; +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VAddReluKernelImpl::useJIT(int d) { + return gen::VVVJitCode::init(d); +} #endif -// TODO(TJ): eq16 test and complete avx512 -#undef INTRI8_FLOAT -#undef MKL_FLOAT -#undef MKL_DOUBLE +#undef DECLARE_STATIC_FUNC + +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); /* VSCAL JitKernel */ template @@ -405,98 +463,9 @@ class VIdentityKernelImpl : public VIdentityKernel { void Compute(const T* x, T* y) const override {} }; -/* VAddRelu JitKernel */ -template -class VAddReluKernelImpl : public VAddReluKernel { - public: - explicit VAddReluKernelImpl(int d) : VAddReluKernel() { this->num_ = d; } - void Compute(const T* x, const T* y, T* z) const override { - for (int i = 0; i < this->num_; ++i) { - z[i] = x[i] + y[i]; - z[i] = z[i] > 0 ? z[i] : 0; - } - } -}; - -#define INTRI8_FLOAT(isa) \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 tmpx = _mm256_loadu_ps(x); \ - __m256 tmpy = _mm256_loadu_ps(y); \ - tmpy = _mm256_add_ps(tmpx, tmpy); \ - tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \ - _mm256_storeu_ps(z, tmpy); \ - } - -#define INTRI16_FLOAT(isa) \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 zeros = _mm256_setzero_ps(); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(y); \ - tmp0 = _mm256_add_ps(tmp0, tmp1); \ - tmp0 = _mm256_max_ps(tmp0, zeros); \ - tmp1 = _mm256_loadu_ps(x + 8); \ - __m256 tmp2 = _mm256_loadu_ps(y + 8); \ - tmp1 = _mm256_add_ps(tmp1, tmp2); \ - tmp1 = _mm256_max_ps(tmp1, zeros); \ - _mm256_storeu_ps(z, tmp0); \ - _mm256_storeu_ps(z + 8, tmp1); \ - } - -#define INTRI_COMMON_FLOAT(isa, block) \ - template <> \ - VAddReluKernelImpl::VAddReluKernelImpl(int d) \ - : VAddReluKernel() { \ - this->num_ = d; \ - this->end_ = d - d % AVX_FLOAT_BLOCK; \ - this->rest_ = d - this->end_; \ - } \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 zeros = _mm256_setzero_ps(); \ - for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ - __m256 tmpx = _mm256_loadu_ps(x + i); \ - __m256 tmpy = _mm256_loadu_ps(y + i); \ - tmpy = _mm256_add_ps(tmpx, tmpy); \ - tmpy = _mm256_max_ps(tmpy, zeros); \ - _mm256_storeu_ps(z + i, tmpy); \ - } \ - for (int i = this->end_; i < this->num_; ++i) { \ - z[i] = x[i] + y[i]; \ - z[i] = z[i] > 0 ? z[i] : 0; \ - } \ - } - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -INTRI_COMMON_FLOAT(jit::avx, kGT16); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); -INTRI_COMMON_FLOAT(jit::avx2, kGT16); -#endif -#ifdef __AVX512F__ -// TODO(TJ): refine avx512 -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); -INTRI_COMMON_FLOAT(jit::avx512f, kGT16); -#endif - -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef INTRI_COMMON_FLOAT - -REGISTER_JITKERNEL_DEPRECATED(vadd, VAddKernel); REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel); REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); -REGISTER_JITKERNEL_DEPRECATED(vaddrelu, VAddReluKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); } // namespace jitkernel diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc index d0932a37bb85bbc41f662a106c8ef5693a72efeb..ba3e917377cf12192a068a9d71238442e12d5e5e 100644 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel { act_cand_d_->Compute(gates, gates); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); - vadd_d_->Compute(gates + d_, gates + d2_, ct); + vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* H_t = act_cell(C_t) * ogated */ act_cell_d_->Compute(ct, gates + d2_); @@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel { /* get fgated and igated*/ vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); - vadd_d2_->Compute(checked, gates + d_, gates + d_); + vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); act_gate_d2_->Compute(gates + d_, gates + d_); /* C_t = C_t-1 * fgated + cand_gated * igated*/ act_cand_d_->Compute(gates, gates); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); - vadd_d_->Compute(gates + d_, gates + d2_, ct); + vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* get ogated*/ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); - vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); + vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); act_gate_d_->Compute(gates + d3_, gates + d3_); /* H_t = act_cell(C_t) * ogated */ act_cell_d_->Compute(ct, gates + d2_); @@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel { vmul_d_->Compute(gates, gates + d_, ct, d_); /* get outgated, put W_oc * C_t on igated */ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); - vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); + vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); /* H_t = act_cell(C_t) * ogated */ act_gate_d_->Compute(gates + d3_, gates + d3_); act_cell_d_->Compute(ct, gates + d2_); diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 34fa2b9a7814dbd96de1e7c4a5be5a88978445bd..9a19424691fad70c161ca6036c5cdfd3b2b22ada 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -371,7 +371,7 @@ void lstm_ctht_better( vtanh_d->Compute(gates, gates); vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d); - vadd_d->Compute(gates + d, gates + d2, ct); + vadd_d->Compute(gates + d, gates + d2, ct, d); /* H_t = act_cell(C_t) * ogated */ vtanh_d->Compute(ct, gates + d2); vmul_d->Compute(gates + d2, gates + d * 3, ht, d); @@ -695,7 +695,7 @@ TEST(JitKernel, vadd) { auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data); + ker->Compute(x_data, y_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -723,8 +723,8 @@ void vaddrelu_better( const paddle::operators::math::jitkernel::VAddKernel>& vadd, const std::shared_ptr< const paddle::operators::math::jitkernel::VReluKernel>& vrelu, - const float* x, const float* y, float* z) { - vadd->Compute(x, y, z); + const float* x, const float* y, float* z, int d) { + vadd->Compute(x, y, z, d); vrelu->Compute(z, z); } @@ -752,12 +752,12 @@ TEST(JitKernel, vaddrelu) { auto trefe = GetCurrentUS(); auto tmkls = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data); + vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d); } auto tmkle = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data); + ker->Compute(x_data, y_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat