提交 6e1ee7fb 编写于 作者: T tensor-tang

cache softmax kernel func

test=develop
上级 c7449227
...@@ -118,26 +118,33 @@ typename KernelTuples::func_type Get( ...@@ -118,26 +118,33 @@ typename KernelTuples::func_type Get(
return GetRefer<KT, KernelTuples>(); return GetRefer<KT, KernelTuples>();
} }
template <KernelType KT, typename KernelTuples> template <KernelType KT, typename KernelTuples, typename PlaceType>
class KernelFuncsCache { class KernelFuncs {
public: public:
KernelFuncsCache() = default; KernelFuncs() = default;
static KernelFuncsCache& Instance() { static KernelFuncs& Cache() {
static thread_local KernelFuncsCache<KT, KernelTuples> g_func_cache; static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
return g_func_cache; return g_func_cache;
} }
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); } bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
typename KernelTuples::func_type At(int key) { return funcs_.at(key); }
void Insert(int key, typename KernelTuples::func_type func) { void Insert(int key, typename KernelTuples::func_type func) {
funcs_.emplace(key, func); funcs_.emplace(key, func);
} }
typename KernelTuples::func_type At(int key) {
if (Has(key)) {
return funcs_.at(key);
}
auto func = Get<KT, KernelTuples, PlaceType>(key);
Insert(key, func);
return func;
}
private: private:
std::unordered_map<int, typename KernelTuples::func_type> funcs_; std::unordered_map<int, typename KernelTuples::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncsCache); DISABLE_COPY_AND_ASSIGN(KernelFuncs);
}; };
const char* to_string(KernelType kt); const char* to_string(KernelType kt);
......
...@@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) { ...@@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) {
} }
void Softmax(const T* x, T* y, int n, int bs) { void Softmax(const T* x, T* y, int n, int bs) {
typename XRNTuples<T>::func_type compute_hmax{nullptr}; auto compute_hmax =
typename XRNTuples<T>::func_type compute_hsum{nullptr}; KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
typename AXYNTuples<T>::func_type compute_vscal{nullptr}; auto compute_hsum =
typename AXYNTuples<T>::func_type compute_vaddbias{nullptr}; KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
typename XYNTuples<T>::func_type compute_vexp{nullptr}; auto compute_vscal =
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) { auto compute_vaddbias =
compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n); KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax); auto compute_vexp =
} else { KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
}
if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
} else {
compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
}
if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
compute_vscal);
} else {
compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
}
if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
n, compute_vaddbias);
} else {
compute_vaddbias =
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
}
if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
} else {
compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
}
for (int i = 0; i < bs; ++i) { for (int i = 0; i < bs; ++i) {
T scalar; T scalar;
......
...@@ -82,8 +82,9 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> { ...@@ -82,8 +82,9 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
const int kClassDim = 1; const int kClassDim = 1;
// 2D data. Batch x C // 2D data. Batch x C
auto compute_softmax = auto compute_softmax =
jit::Get<jit::kSoftmax, jit::SoftmaxTuples<float>, platform::CPUPlace>( jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
in_dims[kClassDim]); platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]); compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册