From 2b0811c3fbd8ac31f986c0ed8fed345fe4e3f526 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 29 Jan 2019 08:01:33 +0000 Subject: [PATCH] refine vadd jitkernel choice test=develop --- paddle/fluid/operators/jit/benchmark.cc | 4 ++++ paddle/fluid/operators/jit/gen/blas.cc | 2 +- paddle/fluid/operators/jit/gen/blas.h | 1 + paddle/fluid/operators/jit/more/mkl/mkl.cc | 2 +- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 5c5a61f6409..9d2ec5f91a9 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -93,6 +93,7 @@ std::vector TestSizes() { template struct BenchFunc { // return this function avg time + // TODO(TJ): clear cache every time double operator()(const typename KernelTuples::func_type tgt, Args... args) { for (int i = 0; i < FLAGS_burning; ++i) { tgt(args...); @@ -172,6 +173,9 @@ void BenchXYZNKernel() { RandomVec(d, y_data); BenchAllImpls, PlaceType>(d, x.data(), y.data(), z_data, d); + // test inplace + BenchAllImpls, PlaceType>(d, x.data(), z_data, + z_data, d); } } diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc index dee6c7b9d3e..5da24c359ed 100644 --- a/paddle/fluid/operators/jit/gen/blas.cc +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -155,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator { class name##Creator : public JitCodeCreator { \ public: \ bool UseMe(const int& attr) const override { \ - return platform::MayIUse(platform::avx); \ + return platform::MayIUse(platform::avx) && attr <= 1024; \ } \ size_t CodeSize(const int& d) const override { \ return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \ diff --git a/paddle/fluid/operators/jit/gen/blas.h b/paddle/fluid/operators/jit/gen/blas.h index de6b33f4672..66a97c1be50 100644 --- a/paddle/fluid/operators/jit/gen/blas.h +++ b/paddle/fluid/operators/jit/gen/blas.h @@ -61,6 +61,7 @@ class VXXJitCode : public JitCode { base += "_Vec"; } base += (with_relu_ ? "_Relu" : ""); + base += "_D" + std::to_string(num_); return base.c_str(); } void genCode() override; diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 28a37198dae..3f6814d6c6b 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -139,7 +139,7 @@ bool VMulKernel::UseMe(const int& d) const { template <> bool VAddKernel::UseMe(const int& d) const { - return platform::MayIUse(platform::avx512f) && d > 512; + return platform::MayIUse(platform::avx) && d > 512; } template <> -- GitLab