From d59f7335515ac769d8f4d288b7eb32b1669490b2 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 28 Jan 2019 18:06:56 +0000 Subject: [PATCH] refine softmax and use with cache test=develop --- paddle/fluid/operators/jit/benchmark.cc | 3 ++ paddle/fluid/operators/jit/gen/act.cc | 28 ++++++++++-- paddle/fluid/operators/jit/helper.h | 22 ++++++++++ paddle/fluid/operators/jit/more/mix/mix.cc | 50 +++++++++++++++++++--- paddle/fluid/operators/jit/more/mkl/mkl.cc | 3 +- paddle/fluid/operators/math/CMakeLists.txt | 2 +- paddle/fluid/operators/math/softmax_impl.h | 28 +++--------- 7 files changed, 102 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 383532d8d22..5c5a61f6409 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -187,6 +187,9 @@ void BenchAXYNKernel() { RandomVec(d, x_data); BenchAllImpls, PlaceType>(d, &a, x.data(), y_data, d); + // test inplace + BenchAllImpls, PlaceType>(d, &a, x.data(), x_data, + d); } } diff --git a/paddle/fluid/operators/jit/gen/act.cc b/paddle/fluid/operators/jit/gen/act.cc index a2a5661b93a..e7a73758790 100644 --- a/paddle/fluid/operators/jit/gen/act.cc +++ b/paddle/fluid/operators/jit/gen/act.cc @@ -81,9 +81,7 @@ void VActJitCode::genCode() { #define DECLARE_ACT_CREATOR(name) \ class name##Creator : public JitCodeCreator { \ public: \ - bool UseMe(const int& attr) const override { \ - return platform::MayIUse(platform::avx); \ - } \ + bool UseMe(const int& attr) const override; \ size_t CodeSize(const int& d) const override; \ std::unique_ptr CreateJitCode(const int& attr) const override { \ return make_unique(attr, CodeSize(attr)); \ @@ -98,6 +96,30 @@ DECLARE_ACT_CREATOR(VSigmoid); DECLARE_ACT_CREATOR(VTanh); // TODO(TJ): tuning use me +bool VReluCreator::UseMe(const int& d) const { + return platform::MayIUse(platform::avx); +} + +bool VSquareCreator::UseMe(const int& d) const { + return platform::MayIUse(platform::avx); +} + +bool VIdentityCreator::UseMe(const int& d) const { + return platform::MayIUse(platform::avx); +} + +bool VExpCreator::UseMe(const int& d) const { + return platform::MayIUse(platform::avx) && d < 32; +} + +bool VSigmoidCreator::UseMe(const int& d) const { + return platform::MayIUse(platform::avx); +} + +bool VTanhCreator::UseMe(const int& d) const { + return platform::MayIUse(platform::avx); +} + size_t VReluCreator::CodeSize(const int& d) const { return 96 /* init size */ + (d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ * diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index fbf34fc4b3d..7bdc45779b7 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -118,6 +118,28 @@ typename KernelTuples::func_type Get( return GetRefer(); } +template +class KernelFuncsCache { + public: + KernelFuncsCache() = default; + static KernelFuncsCache& Instance() { + static thread_local KernelFuncsCache g_func_cache; + return g_func_cache; + } + + bool Has(int key) const { return funcs_.find(key) != funcs_.end(); } + + typename KernelTuples::func_type At(int key) { return funcs_.at(key); } + + void Insert(int key, typename KernelTuples::func_type func) { + funcs_.emplace(key, func); + } + + private: + std::unordered_map funcs_; + DISABLE_COPY_AND_ASSIGN(KernelFuncsCache); +}; + const char* to_string(KernelType kt); const char* to_string(SeqPoolType kt); diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc index 2a75eb23cdf..0f42ac158ca 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.cc +++ b/paddle/fluid/operators/jit/more/mix/mix.cc @@ -49,12 +49,50 @@ void VTanh(const T* x, T* y, int n) { } void Softmax(const T* x, T* y, int n, int bs) { - auto compute_hmax = Get, platform::CPUPlace>(n); - auto compute_hsum = Get, platform::CPUPlace>(n); - auto compute_vscal = Get, platform::CPUPlace>(n); - auto compute_vaddbias = Get, platform::CPUPlace>(n); - auto compute_vexp = - Get, platform::CPUPlace>(n); + typename XRNTuples::func_type compute_hmax{nullptr}; + typename XRNTuples::func_type compute_hsum{nullptr}; + typename AXYNTuples::func_type compute_vscal{nullptr}; + typename AXYNTuples::func_type compute_vaddbias{nullptr}; + typename XYNTuples::func_type compute_vexp{nullptr}; + + if (!KernelFuncsCache>::Instance().Has(n)) { + compute_hmax = Get, platform::CPUPlace>(n); + KernelFuncsCache>::Instance().Insert(n, compute_hmax); + } else { + compute_hmax = KernelFuncsCache>::Instance().At(n); + } + + if (!KernelFuncsCache>::Instance().Has(n)) { + compute_hsum = Get, platform::CPUPlace>(n); + KernelFuncsCache>::Instance().Insert(n, compute_hsum); + } else { + compute_hsum = KernelFuncsCache>::Instance().At(n); + } + + if (!KernelFuncsCache>::Instance().Has(n)) { + compute_vscal = Get, platform::CPUPlace>(n); + KernelFuncsCache>::Instance().Insert(n, + compute_vscal); + } else { + compute_vscal = KernelFuncsCache>::Instance().At(n); + } + + if (!KernelFuncsCache>::Instance().Has(n)) { + compute_vaddbias = Get, platform::CPUPlace>(n); + KernelFuncsCache>::Instance().Insert( + n, compute_vaddbias); + } else { + compute_vaddbias = + KernelFuncsCache>::Instance().At(n); + } + + if (!KernelFuncsCache>::Instance().Has(n)) { + compute_vexp = Get, platform::CPUPlace>(n); + KernelFuncsCache>::Instance().Insert(n, compute_vexp); + } else { + compute_vexp = KernelFuncsCache>::Instance().At(n); + } + for (int i = 0; i < bs; ++i) { T scalar; compute_hmax(x, &scalar, n); diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index b13b8638e28..28a37198dae 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -179,7 +179,8 @@ bool SeqPoolKernel::UseMe(const seq_pool_attr_t& attr) const { template <> bool SoftmaxKernel::UseMe(const int& d) const { - return true; + // tuned on avx2 + return platform::MayIUse(platform::avx) && d < 60; } #define AWALYS_USE_ME_WITH_DOUBLE(func) \ diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 6bbb7155dda..e20524012a5 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -53,7 +53,7 @@ math_library(sequence2batch) math_library(sequence_padding) math_library(sequence_pooling DEPS math_function jit_kernel_helper) math_library(sequence_scale) -math_library(softmax DEPS math_function) +math_library(softmax DEPS math_function jit_kernel_helper) math_library(beam_search DEPS math_function) math_library(matrix_bit_code) diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index 1d9d98b1064..1ff9ff684fc 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -16,8 +16,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { namespace math { @@ -81,28 +81,10 @@ class SoftmaxFunctor> { const int kBatchDim = 0; const int kClassDim = 1; // 2D data. Batch x C - const int batch_size = in_dims[kBatchDim]; - const int num_classes = in_dims[kClassDim]; - std::vector entities(batch_size); - auto blas = math::GetBlas(context); - for (int n = 0; n < batch_size; ++n) { - entities[n] = in_data[n * num_classes]; - for (int c = 1; c < num_classes; ++c) { - entities[n] = in_data[n * num_classes + c] > entities[n] - ? in_data[n * num_classes + c] - : entities[n]; - } - for (int c = 0; c < num_classes; ++c) { - out_data[n * num_classes + c] = - in_data[n * num_classes + c] - entities[n]; - } - } - - blas.VEXP(num_classes * batch_size, out_data, out_data); - for (int n = 0; n < batch_size; ++n) { - auto sum = blas.ASUM(num_classes, &out_data[n * num_classes], 1); - blas.SCAL(num_classes, 1.0f / sum, &out_data[n * num_classes]); - } + auto compute_softmax = + jit::Get, platform::CPUPlace>( + in_dims[kClassDim]); + compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]); } }; -- GitLab