diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h
index 2d7e877a77b458e59b131df976169fb58b634f24..87220d4019fc9337fb8355172ca4f1372cfd4558 100644
--- a/paddle/fluid/operators/math/fc_compute.h
+++ b/paddle/fluid/operators/math/fc_compute.h
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"  // TODO(TJ): add deps
+#include "paddle/fluid/operators/math/jit_kernel.h"
 
 DECLARE_int32(paddle_num_threads);
 
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index a486a0ca8040fb1dd8facf25762b70e34290affd..c88b17b012d1b9cd59220f6b37cb2ecb8a1551a4 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -447,20 +447,17 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
 #ifdef __AVX__
 INTRI8_FLOAT(jit::avx);
 INTRI16_FLOAT(jit::avx);
-INTRI_COMMON_FLOAT(jit::avx, kGT8LT16);
 INTRI_COMMON_FLOAT(jit::avx, kGT16);
 #endif
 #ifdef __AVX2__
 INTRI8_FLOAT(jit::avx2);
 INTRI16_FLOAT(jit::avx2);
-INTRI_COMMON_FLOAT(jit::avx2, kGT8LT16);
 INTRI_COMMON_FLOAT(jit::avx2, kGT16);
 #endif
 #ifdef __AVX512F__
 // TODO(TJ): refine avx512
 INTRI8_FLOAT(jit::avx512f);
 INTRI16_FLOAT(jit::avx512f);
-INTRI_COMMON_FLOAT(jit::avx512f, kGT8LT16);
 INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
 #endif
 
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 7fdd1c6b76aebcea757540e7312a679b8c08402a..c9e6ab740da4e39de3dc41f9df352b88e696c38d 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -712,6 +712,63 @@ TEST(JitKernel, vadd) {
   }
 }
 
+void vaddrelu_ref(const int n, const float* x, const float* y, float* z) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] + y[i];
+    z[i] = z[i] > 0 ? z[i] : 0;
+  }
+}
+void vaddrelu_better(
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
+    const float* x, const float* y, float* z) {
+  vadd->Compute(x, y, z);
+  vrelu->Compute(z, z);
+}
+
+TEST(JitKernel, vaddrelu) {
+  namespace jit = paddle::operators::math::jitkernel;
+  for (int d : {7, 8, 15, 16, 30, 256, 512}) {
+    std::vector<float> x(d), y(d);
+    std::vector<float> zref(d), ztgt(d);
+    RandomVec<float>(d, x.data());
+    RandomVec<float>(d, y.data());
+    const auto& ker =
+        jit::KernelPool::Instance().template Get<jit::VAddReluKernel<float>>(d);
+    const auto& vadd =
+        jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
+    const auto& vrelu =
+        jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
+    const float* x_data = x.data();
+    const float* y_data = y.data();
+    float* ztgt_data = ztgt.data();
+    float* zref_data = zref.data();
+    auto trefs = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vaddrelu_ref(d, x_data, y_data, zref_data);
+    }
+    auto trefe = GetCurrentUS();
+    auto tmkls = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data);
+    }
+    auto tmkle = GetCurrentUS();
+    auto ttgts = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      ker->Compute(x_data, y_data, ztgt_data);
+    }
+    auto ttgte = GetCurrentUS();
+    VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
+            << " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
+            << "tgt takes: " << (ttgte - ttgts) / repeat;
+    for (int i = 0; i < d; ++i) {
+      EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
+    }
+  }
+}
+
 TEST(JitKernel, pool) {
   namespace jit = paddle::operators::math::jitkernel;
   const int frame_size = 4;
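
Usage note (a minimal sketch, not part of the patch): the fused VAddRelu kernel added above is fetched from the KernelPool exactly like the vadd/vrelu kernels in the test, so a caller can replace the two-step vadd + vrelu path with a single Compute call. The helper name FusedAddRelu below is hypothetical; only the Get and Compute calls already shown in the test are assumed.

#include <vector>
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jit = paddle::operators::math::jitkernel;

// Hypothetical helper: z[i] = max(x[i] + y[i], 0) in one fused pass.
void FusedAddRelu(const std::vector<float>& x, const std::vector<float>& y,
                  std::vector<float>* z) {
  const int n = static_cast<int>(x.size());
  // The pool is expected to cache kernels per element count, so repeated
  // lookups for the same n should be cheap.
  const auto& ker =
      jit::KernelPool::Instance().template Get<jit::VAddReluKernel<float>>(n);
  ker->Compute(x.data(), y.data(), z->data());
}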