From 2d0ff6a3c265067208d53b4ef5faffb474a6508f Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Fri, 28 Sep 2018 23:16:41 +0800
Subject: [PATCH] add vexp and unit test

---
 paddle/fluid/operators/math/CMakeLists.txt    |   3 +-
 paddle/fluid/operators/math/jit_kernel.h      |   6 +
 .../fluid/operators/math/jit_kernel_blas.cc   | 158 +++++-------------
 paddle/fluid/operators/math/jit_kernel_exp.cc | 115 +++++++++++++
 .../fluid/operators/math/jit_kernel_macro.h   |  94 +++++++++++
 .../fluid/operators/math/jit_kernel_test.cc   |  63 ++++++-
 6 files changed, 318 insertions(+), 121 deletions(-)
 create mode 100644 paddle/fluid/operators/math/jit_kernel_exp.cc
 create mode 100644 paddle/fluid/operators/math/jit_kernel_macro.h

diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 9763d14d54a..2a389ea1c84 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -76,5 +76,6 @@ if(WITH_GPU)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
-cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_lstm.cc DEPS cpu_info cblas)
+cc_library(jit_kernel_exp SRCS jit_kernel_exp.cc DEPS cpu_info cblas activation_functions)
+cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_lstm.cc DEPS cpu_info cblas jit_kernel_exp)
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 9cb15f9bdb2..0a16a878558 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -82,6 +82,12 @@ class VScalKernel : public Kernel {
   virtual void Compute(const int n, const T a, T *x) = 0;
 };
 
+template <typename T>
+class VExpKernel : public Kernel {
+ public:
+  virtual void Compute(const int n, const T *x, T *y) = 0;
+};
+
 template <typename T>
 class LSTMKernel : public Kernel {
  public:
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index 0ec9ac10c81..a08d53f4963 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
 #include <string>
+#include "paddle/fluid/operators/math/jit_kernel_macro.h"
 #ifdef PADDLE_WITH_MKLML
 #include "paddle/fluid/platform/dynload/mklml.h"
 #endif
@@ -29,71 +30,6 @@ namespace jitkernel {
 
 namespace jit = platform::jit;
 
-#define NEW_IMPL(src, t, isa, k)           \
-  p = std::dynamic_pointer_cast<src<t>>(   \
-      std::make_shared<src##Impl<t, isa, k>>())
-
-#define SEARCH_BLOCK(src, t, isa)                             \
-  if (d < AVX_FLOAT_BLOCK) {                                  \
-    NEW_IMPL(src, t, isa, kLT8);                              \
-  } else if (d == AVX_FLOAT_BLOCK) {                          \
-    NEW_IMPL(src, t, isa, kEQ8);                              \
-  } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
-    NEW_IMPL(src, t, isa, kGT8LT16);                          \
-  } else if (d == AVX512_FLOAT_BLOCK) {                       \
-    NEW_IMPL(src, t, isa, kEQ16);                             \
-  } else {                                                    \
-    NEW_IMPL(src, t, isa, kGT16);                             \
-  }
-
-#define SEARCH_ISA_BLOCK(src, t)          \
-  if (jit::MayIUse(jit::avx512f)) {       \
-    SEARCH_BLOCK(src, t, jit::avx512f);   \
-  } else if (jit::MayIUse(jit::avx2)) {   \
-    SEARCH_BLOCK(src, t, jit::avx2);      \
-  } else if (jit::MayIUse(jit::avx)) {    \
-    SEARCH_BLOCK(src, t, jit::avx);       \
-  } else {                                \
-    SEARCH_BLOCK(src, t, jit::isa_any);   \
-  }
-
-#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key)        \
-  template <>                                                              \
-  const std::shared_ptr<ker_class<ker_dtype>>                              \
-  KernelPool::Get<ker_class<ker_dtype>>(int d) {                           \
-    std::string key = #ker_key #dtype_key + std::to_string(d);             \
-    if (kers_.find(key) == kers_.end()) {                                  \
-      std::shared_ptr<ker_class<ker_dtype>> p;                             \
-      SEARCH_ISA_BLOCK(ker_class, ker_dtype);                              \
-      kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});           \
-      return p;                                                            \
-    }                                                                      \
-    return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
-  }
-
-#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
-  DEFINE_WITH_DTYPE(ker_key, ker_class, float, f);  \
-  DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
-
-#define FOR_EACH_ISA(macro_, block) \
-  macro_(jit::avx512f, block);      \
-  macro_(jit::avx2, block);         \
-  macro_(jit::avx, block);          \
-  macro_(jit::isa_any, block)
-
-#define FOR_EACH_BLOCK(macro_, isa) \
-  macro_(isa, kLT8);                \
-  macro_(isa, kEQ8);                \
-  macro_(isa, kGT8LT16);            \
-  macro_(isa, kEQ16);               \
-  macro_(isa, kGT16)
-
-#define FOR_EACH_ISA_BLOCK(macro_)      \
-  FOR_EACH_BLOCK(macro_, jit::avx512f); \
-  FOR_EACH_BLOCK(macro_, jit::avx2);    \
-  FOR_EACH_BLOCK(macro_, jit::avx);     \
-  FOR_EACH_BLOCK(macro_, jit::isa_any)
-
 /* VMUL JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VMulKernelImpl : public VMulKernel<T> {
@@ -106,25 +42,25 @@ class VMulKernelImpl : public VMulKernel<T> {
 };
 
 #ifdef PADDLE_WITH_MKLML
-#define VMUL_MKL_FLOAT(isa, block)                                              \
+#define MKL_FLOAT(isa, block)                                                   \
   template <>                                                                   \
   void VMulKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
                                                   const float* y, float* z) {  \
     platform::dynload::vsMul(n, x, y, z);                                       \
   }
 
-#define VMUL_MKL_DOUBLE(isa, block)                                \
+#define MKL_DOUBLE(isa, block)                                     \
   template <>                                                      \
   void VMulKernelImpl<double, isa, block>::Compute(                \
       const int n, const double* x, const double* y, double* z) {  \
     platform::dynload::vdMul(n, x, y, z);                          \
   }
-FOR_EACH_ISA(VMUL_MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(VMUL_MKL_DOUBLE);
+FOR_EACH_ISA(MKL_FLOAT, kGT16);
+FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #endif
 
-#define VMUL_INTRI8_FLOAT(isa)                                                 \
+#define INTRI8_FLOAT(isa)                                                      \
   template <>                                                                  \
   void VMulKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
                                                  const float* y, float* z) {  \
@@ -137,19 +73,18 @@ FOR_EACH_ISA_BLOCK(VMUL_MKL_DOUBLE);
 // avx > for > mkl
 #ifdef __AVX__
-VMUL_INTRI8_FLOAT(jit::avx);
+INTRI8_FLOAT(jit::avx);
 #endif
 #ifdef __AVX2__
-VMUL_INTRI8_FLOAT(jit::avx2);
+INTRI8_FLOAT(jit::avx2);
 #endif
 #ifdef __AVX512F__
-VMUL_INTRI8_FLOAT(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f);
 #endif
-
 // TODO(TJ): eq16 test and complete avx512
-#undef VMUL_INTRI8_FLOAT
-#undef VMUL_MKL_FLOAT
-#undef VMUL_MKL_DOUBLE
+#undef INTRI8_FLOAT
+#undef MKL_FLOAT
+#undef MKL_DOUBLE
 
 /* VADD JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VAddKernelImpl : public VAddKernel<T> {
@@ -163,25 +98,25 @@ class VAddKernelImpl : public VAddKernel<T> {
 };
 
 #ifdef PADDLE_WITH_MKLML
-#define VADD_MKL_FLOAT(isa, block)                                              \
+#define MKL_FLOAT(isa, block)                                                   \
   template <>                                                                   \
   void VAddKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
                                                   const float* y, float* z) {  \
     platform::dynload::vsAdd(n, x, y, z);                                       \
   }
 
-#define VADD_MKL_DOUBLE(isa, block)                                \
+#define MKL_DOUBLE(isa, block)                                     \
   template <>                                                      \
   void VAddKernelImpl<double, isa, block>::Compute(                \
       const int n, const double* x, const double* y, double* z) {  \
     platform::dynload::vdAdd(n, x, y, z);                          \
   }
-FOR_EACH_ISA(VADD_MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(VADD_MKL_DOUBLE);
+FOR_EACH_ISA(MKL_FLOAT, kGT16);
+FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #endif
 
-#define VADD_INTRI8_FLOAT(isa)                                                 \
+#define INTRI8_FLOAT(isa)                                                      \
   template <>                                                                  \
   void VAddKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
                                                  const float* y, float* z) {  \
@@ -192,19 +127,19 @@ FOR_EACH_ISA_BLOCK(VADD_MKL_DOUBLE);
     _mm256_storeu_ps(z, tmpx);                                                 \
   }
 #ifdef __AVX__
-VADD_INTRI8_FLOAT(jit::avx);
+INTRI8_FLOAT(jit::avx);
 #endif
 #ifdef __AVX2__
-VADD_INTRI8_FLOAT(jit::avx2);
+INTRI8_FLOAT(jit::avx2);
 #endif
 #ifdef __AVX512F__
-VADD_INTRI8_FLOAT(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f);
 #endif
 // TODO(TJ): eq16 test and complete avx512
-#undef VADD_INTRI8_FLOAT
-#undef VADD_MKL_FLOAT
-#undef VADD_MKL_DOUBLE
+#undef INTRI8_FLOAT
+#undef MKL_FLOAT
+#undef MKL_DOUBLE
 
 /* VSCAL JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VScalKernelImpl : public VScalKernel<T> {
@@ -223,25 +158,25 @@ class VScalKernelImpl : public VScalKernel<T> {
 };
 
 #ifdef PADDLE_WITH_MKLML
-#define VSCAL_MKL_FLOAT(isa, block)                                              \
+#define MKL_FLOAT(isa, block)                                                    \
   template <>                                                                    \
   void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a,  \
                                                    float* x) {                  \
     platform::dynload::cblas_sscal(n, a, x, 1);                                  \
   }
 
-#define VSCAL_MKL_DOUBLE(isa, block)                 \
+#define MKL_DOUBLE(isa, block)                       \
   template <>                                        \
   void VScalKernelImpl<double, isa, block>::Compute( \
       const int n, const double a, double* x) {      \
     platform::dynload::cblas_dscal(n, a, x, 1);      \
   }
-FOR_EACH_ISA(VSCAL_MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(VSCAL_MKL_DOUBLE);
+FOR_EACH_ISA(MKL_FLOAT, kGT16);
+FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #endif
 
-#define VSCAL_INTRI8(isa)                                                       \
+#define INTRI8_FLOAT(isa)                                                       \
   template <>                                                                   \
   void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a,  \
                                                   const float* x, float* y) {  \
@@ -251,7 +186,7 @@ FOR_EACH_ISA_BLOCK(VSCAL_MKL_DOUBLE);
     tmp = _mm256_mul_ps(tmp, scalar);                                           \
     _mm256_storeu_ps(y, tmp);                                                   \
   }
-#define VSCAL_INTRI8_INPLACE(isa)                                               \
+#define INTRI8_INPLACE_FLOAT(isa)                                               \
   template <>                                                                   \
   void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a,  \
                                                   float* x) {                   \
@@ -263,36 +198,27 @@ FOR_EACH_ISA_BLOCK(VSCAL_MKL_DOUBLE);
   }
 
 #ifdef __AVX__
-VSCAL_INTRI8(jit::avx);
-VSCAL_INTRI8_INPLACE(jit::avx);
+INTRI8_FLOAT(jit::avx);
+INTRI8_INPLACE_FLOAT(jit::avx);
 #endif
 #ifdef __AVX2__
-VSCAL_INTRI8(jit::avx2);
-VSCAL_INTRI8_INPLACE(jit::avx2);
+INTRI8_FLOAT(jit::avx2);
+INTRI8_INPLACE_FLOAT(jit::avx2);
 #endif
 #ifdef __AVX512F__
-VSCAL_INTRI8(jit::avx512f);
-VSCAL_INTRI8_INPLACE(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f);
+INTRI8_INPLACE_FLOAT(jit::avx512f);
 #endif
 
 // TODO(TJ): eq16 test and complete avx512
-#undef VSCAL_INTRI8
-#undef VSCAL_INTRI8_INPLACE
-#undef VSCAL_MKL_FLOAT
-#undef VSCAL_MKL_DOUBLE
-
-REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
-REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
-REGISTER_BLAS_JITKERNEL(vscal, VScalKernel);
+#undef INTRI8_FLOAT
+#undef INTRI8_INPLACE_FLOAT
+#undef MKL_FLOAT
+#undef MKL_DOUBLE
 
-#undef FOR_EACH_ISA
-#undef FOR_EACH_BLOCK
-#undef FOR_EACH_ISA_BLOCK
-#undef REGISTER_BLAS_JITKERNEL
-#undef DEFINE_WITH_DTYPE
-#undef SEARCH_ISA_BLOCK
-#undef SEARCH_BLOCK
-#undef NEW_IMPL
+REGISTER_JITKERNEL(vmul, VMulKernel);
+REGISTER_JITKERNEL(vadd, VAddKernel);
+REGISTER_JITKERNEL(vscal, VScalKernel);
 
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
new file mode 100644
index 00000000000..5f04ba97be0
--- /dev/null
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include <cmath>
+#include "paddle/fluid/operators/math/jit_kernel_macro.h"
+#ifdef PADDLE_WITH_MKLML
+#include "paddle/fluid/platform/dynload/mklml.h"
+#endif
+
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+#ifdef __AVX__
+namespace detail {
+__m256 Exp(__m256 a);
+}  // namespace detail
+#endif
+
+namespace jitkernel {
+
+namespace jit = platform::jit;
+
+/* VExp JitKernel */
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+class VExpKernelImpl : public VExpKernel<T> {
+ public:
+  void Compute(const int n, const T* x, T* y) override {
+    for (int i = 0; i < n; ++i) {
+      y[i] = std::exp(x[i]);
+    }
+  }
+};
+
+#ifdef PADDLE_WITH_MKLML
+#define MKL_FLOAT(isa, block)                                                   \
+  template <>                                                                   \
+  void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
+                                                  float* y) {                  \
+    platform::dynload::vsExp(n, x, y);                                          \
+  }
+
+#define MKL_DOUBLE(isa, block)                       \
+  template <>                                        \
+  void VExpKernelImpl<double, isa, block>::Compute(  \
+      const int n, const double* x, double* y) {     \
+    platform::dynload::vdExp(n, x, y);               \
+  }
+FOR_EACH_ISA(MKL_FLOAT, kLT8);
+FOR_EACH_ISA(MKL_FLOAT, kGT8LT16);
+FOR_EACH_ISA(MKL_FLOAT, kGT16);
+FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
+#endif
+
+#define INTRI8_FLOAT(isa)                                                      \
+  template <>                                                                  \
+  void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
+                                                 float* y) {                  \
+    __m256 tmp = _mm256_loadu_ps(x);                                           \
+    _mm256_storeu_ps(y, detail::Exp(tmp));                                     \
+  }
+
+#define INTRI16_FLOAT(isa)                                                      \
+  template <>                                                                   \
+  void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \
+                                                  float* y) {                  \
+    __m256 tmp0 = _mm256_loadu_ps(x);                                           \
+    __m256 tmp1 = _mm256_loadu_ps(x + 8);                                       \
+    tmp0 = detail::Exp(tmp0);                                                   \
+    tmp1 = detail::Exp(tmp1);                                                   \
+    _mm256_storeu_ps(y, tmp0);                                                  \
+    _mm256_storeu_ps(y + 8, tmp1);                                              \
+  }
+
+#ifdef __AVX__
+INTRI8_FLOAT(jit::avx);
+INTRI16_FLOAT(jit::avx);
+#endif
+#ifdef __AVX2__
+INTRI8_FLOAT(jit::avx2);
+INTRI16_FLOAT(jit::avx2);
+#endif
+#ifdef __AVX512F__
+INTRI8_FLOAT(jit::avx512f);
+INTRI16_FLOAT(jit::avx512f);
+#endif
+// TODO(TJ): eq16 test and complete avx512
+
+#undef INTRI8_FLOAT
+#undef INTRI16_FLOAT
+#undef MKL_FLOAT
+#undef MKL_DOUBLE
+
+REGISTER_JITKERNEL(vexp, VExpKernel);
+
+}  // namespace jitkernel
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
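A minimal usage sketch for the new VExpKernel, mirroring the vexp unit test added at the end of this patch; the KernelPool API and kernel names come from the code above, while the demo function itself is illustrative only and not part of the patch.

// Illustrative only: fetch the cached vexp kernel for size d and run it.
// KernelPool, VExpKernel and Compute() are defined in the patch above.
#include <vector>
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jit = paddle::operators::math::jitkernel;

void vexp_demo() {
  const int d = 8;  // one AVX register of floats, dispatches to the kEQ8 path
  std::vector<float> x(d, 1.f), y(d);
  const auto& ker =
      jit::KernelPool::Instance().template Get<jit::VExpKernel<float>>(d);
  ker->Compute(d, x.data(), y.data());  // y[i] = exp(x[i])
}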
diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h
new file mode 100644
index 00000000000..239583f3018
--- /dev/null
+++ b/paddle/fluid/operators/math/jit_kernel_macro.h
@@ -0,0 +1,94 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace jitkernel {
+
+namespace jit = platform::jit;
+
+#define NEW_JITKERNEL_IMPL(src, t, isa, k)   \
+  p = std::dynamic_pointer_cast<src<t>>(     \
+      std::make_shared<src##Impl<t, isa, k>>())
+
+#define SEARCH_BLOCK(src, t, isa)                             \
+  if (d < AVX_FLOAT_BLOCK) {                                  \
+    NEW_JITKERNEL_IMPL(src, t, isa, kLT8);                    \
+  } else if (d == AVX_FLOAT_BLOCK) {                          \
+    NEW_JITKERNEL_IMPL(src, t, isa, kEQ8);                    \
+  } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
+    NEW_JITKERNEL_IMPL(src, t, isa, kGT8LT16);                \
+  } else if (d == AVX512_FLOAT_BLOCK) {                       \
+    NEW_JITKERNEL_IMPL(src, t, isa, kEQ16);                   \
+  } else {                                                    \
+    NEW_JITKERNEL_IMPL(src, t, isa, kGT16);                   \
+  }
+
+#define SEARCH_ISA_BLOCK(src, t)          \
+  if (jit::MayIUse(jit::avx512f)) {       \
+    SEARCH_BLOCK(src, t, jit::avx512f);   \
+  } else if (jit::MayIUse(jit::avx2)) {   \
+    SEARCH_BLOCK(src, t, jit::avx2);      \
+  } else if (jit::MayIUse(jit::avx)) {    \
+    SEARCH_BLOCK(src, t, jit::avx);       \
+  } else {                                \
+    SEARCH_BLOCK(src, t, jit::isa_any);   \
+  }
+
+#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key)      \
+  template <>                                                               \
+  const std::shared_ptr<ker_class<ker_dtype>>                               \
+  KernelPool::Get<ker_class<ker_dtype>>(int d) {                            \
+    std::string key = #ker_key #dtype_key + std::to_string(d);              \
+    if (kers_.find(key) == kers_.end()) {                                   \
+      std::shared_ptr<ker_class<ker_dtype>> p;                              \
+      SEARCH_ISA_BLOCK(ker_class, ker_dtype);                               \
+      kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});            \
+      return p;                                                             \
+    }                                                                       \
+    return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key));  \
+  }
+
+#define REGISTER_JITKERNEL(ker_key, ker_class)        \
+  JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f); \
+  JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d)
+
+#define FOR_EACH_ISA(macro_, block) \
+  macro_(jit::avx512f, block);      \
+  macro_(jit::avx2, block);         \
+  macro_(jit::avx, block);          \
+  macro_(jit::isa_any, block)
+
+#define FOR_EACH_BLOCK(macro_, isa) \
+  macro_(isa, kLT8);                \
+  macro_(isa, kEQ8);                \
+  macro_(isa, kGT8LT16);            \
+  macro_(isa, kEQ16);               \
+  macro_(isa, kGT16)
+
+#define FOR_EACH_ISA_BLOCK(macro_)      \
+  FOR_EACH_BLOCK(macro_, jit::avx512f); \
+  FOR_EACH_BLOCK(macro_, jit::avx2);    \
+  FOR_EACH_BLOCK(macro_, jit::avx);     \
+  FOR_EACH_BLOCK(macro_, jit::isa_any)
+
+}  // namespace jitkernel
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
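To make the macro plumbing above concrete, here is a rough hand-expansion of what REGISTER_JITKERNEL(vexp, VExpKernel) generates for the float case, following JITKERNEL_WITH_DTYPE, SEARCH_ISA_BLOCK and SEARCH_BLOCK. This is an illustration written for this review, not code from the patch, and all but one ISA/block branch is elided.

// Hand-expanded sketch of JITKERNEL_WITH_DTYPE(vexp, VExpKernel, float, f),
// placed inside namespace jitkernel. Only the avx512f / kEQ8 branch is
// spelled out; the remaining branches repeat the same pattern.
template <>
const std::shared_ptr<VExpKernel<float>>
KernelPool::Get<VExpKernel<float>>(int d) {
  std::string key = "vexpf" + std::to_string(d);  // e.g. "vexpf8"
  if (kers_.find(key) == kers_.end()) {
    std::shared_ptr<VExpKernel<float>> p;
    if (jit::MayIUse(jit::avx512f)) {
      if (d == AVX_FLOAT_BLOCK) {  // d == 8
        p = std::dynamic_pointer_cast<VExpKernel<float>>(
            std::make_shared<VExpKernelImpl<float, jit::avx512f, kEQ8>>());
      }
      // ... other block sizes elided ...
    }
    // ... avx2 / avx / isa_any branches elided ...
    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
    return p;
  }
  return std::dynamic_pointer_cast<VExpKernel<float>>(kers_.at(key));
}

The pool therefore hands back one shared kernel instance per (name, dtype, size) key, so repeated Get calls on the hot path only pay for a map lookup.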
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index ccd687d587d..a23d5fff04e 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
 #include <sys/time.h>
-#include <cstring>
+#include <cstring>  // for memcpy
 #include <string>
 #include <vector>
 #include "gflags/gflags.h"
@@ -38,17 +38,72 @@ inline double GetCurrentUS() {
 }
 
 template <typename T>
-void RandomVec(const int n, T* a) {
+void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
+               const T upper = static_cast<T>(20.f)) {
   static unsigned int seed = 100;
   std::mt19937 rng(seed++);
   std::uniform_real_distribution<double> uniform_dist(0, 1);
-  const T lower = static_cast<T>(-20.f);
-  const T upper = static_cast<T>(20.f);
   for (int i = 0; i < n; ++i) {
     a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
   }
 }
 
+void vexp_ref(const int n, const float* x, float* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = std::exp(x[i]);
+  }
+}
+
+#ifdef PADDLE_WITH_MKLML
+void vexp_mkl(const int n, const float* x, float* y) {
+  paddle::platform::dynload::vsExp(n, x, y);
+}
+#endif
+
+TEST(JitKernel, vexp) {
+  namespace jit = paddle::operators::math::jitkernel;
+  for (int d : {7, 8, 15, 16, 30, 128}) {
+    std::vector<float> x(d);
+    std::vector<float> zref(d), ztgt(d);
+    RandomVec<float>(d, x.data(), -2.f, 2.f);
+    const auto& ker =
+        jit::KernelPool::Instance().template Get<jit::VExpKernel<float>>(d);
+    const float* x_data = x.data();
+    float* ztgt_data = ztgt.data();
+    float* zref_data = zref.data();
+    auto trefs = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vexp_ref(d, x_data, zref_data);
+    }
+    auto trefe = GetCurrentUS();
+
+#ifdef PADDLE_WITH_MKLML
+    auto tmkls = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vexp_mkl(d, x_data, zref_data);
+    }
+    auto tmkle = GetCurrentUS();
+#endif
+
+    auto ttgts = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      ker->Compute(d, x_data, ztgt_data);
+    }
+    auto ttgte = GetCurrentUS();
+
+    VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
+#ifdef PADDLE_WITH_MKLML
+            << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, "
+#else
+            << " us, "
+#endif
+            << "tgt takes: " << (ttgte - ttgts) / repeat;
+    for (int i = 0; i < d; ++i) {
+      EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
+    }
+  }
+}
+
 void vscal_ref(const int n, const float a, const float* x, float* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a * x[i];
-- 
GitLab
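As a closing note on test coverage: the sizes {7, 8, 15, 16, 30, 128} exercised by the vexp test map onto the five jit_block cases that SEARCH_BLOCK dispatches over. A standalone restatement of that mapping, assuming the AVX_FLOAT_BLOCK == 8 and AVX512_FLOAT_BLOCK == 16 constants from jit_kernel.h (not shown in this patch):

// Illustrative restatement of the SEARCH_BLOCK dispatch; the enumerators are
// the ones referenced throughout the patch, their definition is assumed here.
enum jit_block { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 };

inline jit_block ClassifyBlock(int d) {
  if (d < 8) return kLT8;       // e.g. d == 7
  if (d == 8) return kEQ8;      // one AVX register of floats, INTRI8 path
  if (d < 16) return kGT8LT16;  // e.g. d == 15
  if (d == 16) return kEQ16;    // two AVX registers, INTRI16 path
  return kGT16;                 // e.g. d == 30 or 128, MKL or scalar loop
}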