diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 7bdc45779b7d39d36db0d52ca9361943cdcdef3e..7e8049c0e1d162b6f2e6daae058b0b941ba14e90 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -118,26 +118,33 @@ typename KernelTuples::func_type Get( return GetRefer(); } -template -class KernelFuncsCache { +template +class KernelFuncs { public: - KernelFuncsCache() = default; - static KernelFuncsCache& Instance() { - static thread_local KernelFuncsCache g_func_cache; + KernelFuncs() = default; + static KernelFuncs& Cache() { + static thread_local KernelFuncs g_func_cache; return g_func_cache; } bool Has(int key) const { return funcs_.find(key) != funcs_.end(); } - typename KernelTuples::func_type At(int key) { return funcs_.at(key); } - void Insert(int key, typename KernelTuples::func_type func) { funcs_.emplace(key, func); } + typename KernelTuples::func_type At(int key) { + if (Has(key)) { + return funcs_.at(key); + } + auto func = Get(key); + Insert(key, func); + return func; + } + private: std::unordered_map funcs_; - DISABLE_COPY_AND_ASSIGN(KernelFuncsCache); + DISABLE_COPY_AND_ASSIGN(KernelFuncs); }; const char* to_string(KernelType kt); diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc index 0f42ac158ca7926981df55936cb903d5f4ae4806..0036d1c238b17768c4df61af22a85588990e1815 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.cc +++ b/paddle/fluid/operators/jit/more/mix/mix.cc @@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) { } void Softmax(const T* x, T* y, int n, int bs) { - typename XRNTuples::func_type compute_hmax{nullptr}; - typename XRNTuples::func_type compute_hsum{nullptr}; - typename AXYNTuples::func_type compute_vscal{nullptr}; - typename AXYNTuples::func_type compute_vaddbias{nullptr}; - typename XYNTuples::func_type compute_vexp{nullptr}; - - if (!KernelFuncsCache>::Instance().Has(n)) { - compute_hmax = Get, platform::CPUPlace>(n); - KernelFuncsCache>::Instance().Insert(n, compute_hmax); - } else { - compute_hmax = KernelFuncsCache>::Instance().At(n); - } - - if (!KernelFuncsCache>::Instance().Has(n)) { - compute_hsum = Get, platform::CPUPlace>(n); - KernelFuncsCache>::Instance().Insert(n, compute_hsum); - } else { - compute_hsum = KernelFuncsCache>::Instance().At(n); - } - - if (!KernelFuncsCache>::Instance().Has(n)) { - compute_vscal = Get, platform::CPUPlace>(n); - KernelFuncsCache>::Instance().Insert(n, - compute_vscal); - } else { - compute_vscal = KernelFuncsCache>::Instance().At(n); - } - - if (!KernelFuncsCache>::Instance().Has(n)) { - compute_vaddbias = Get, platform::CPUPlace>(n); - KernelFuncsCache>::Instance().Insert( - n, compute_vaddbias); - } else { - compute_vaddbias = - KernelFuncsCache>::Instance().At(n); - } - - if (!KernelFuncsCache>::Instance().Has(n)) { - compute_vexp = Get, platform::CPUPlace>(n); - KernelFuncsCache>::Instance().Insert(n, compute_vexp); - } else { - compute_vexp = KernelFuncsCache>::Instance().At(n); - } + auto compute_hmax = + KernelFuncs, platform::CPUPlace>::Cache().At(n); + auto compute_hsum = + KernelFuncs, platform::CPUPlace>::Cache().At(n); + auto compute_vscal = + KernelFuncs, platform::CPUPlace>::Cache().At(n); + auto compute_vaddbias = + KernelFuncs, platform::CPUPlace>::Cache().At(n); + auto compute_vexp = + KernelFuncs, platform::CPUPlace>::Cache().At(n); for (int i = 0; i < bs; ++i) { T scalar; diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index 1ff9ff684fc8001afb0f768a033b4c5bd1592702..a1cb3f972826a67721b00ce6df0ec48cc34d6e03 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -82,8 +82,9 @@ class SoftmaxFunctor> { const int kClassDim = 1; // 2D data. Batch x C auto compute_softmax = - jit::Get, platform::CPUPlace>( - in_dims[kClassDim]); + jit::KernelFuncs, + platform::CPUPlace>::Cache() + .At(in_dims[kClassDim]); compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]); } };