提交 6e1ee7fb 编写于 作者: T tensor-tang

cache softmax kernel func

test=develop
上级 c7449227
......@@ -118,26 +118,33 @@ typename KernelTuples::func_type Get(
return GetRefer<KT, KernelTuples>();
}
template <KernelType KT, typename KernelTuples>
class KernelFuncsCache {
template <KernelType KT, typename KernelTuples, typename PlaceType>
class KernelFuncs {
public:
KernelFuncsCache() = default;
static KernelFuncsCache& Instance() {
static thread_local KernelFuncsCache<KT, KernelTuples> g_func_cache;
KernelFuncs() = default;
static KernelFuncs& Cache() {
static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
return g_func_cache;
}
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
typename KernelTuples::func_type At(int key) { return funcs_.at(key); }
void Insert(int key, typename KernelTuples::func_type func) {
funcs_.emplace(key, func);
}
typename KernelTuples::func_type At(int key) {
if (Has(key)) {
return funcs_.at(key);
}
auto func = Get<KT, KernelTuples, PlaceType>(key);
Insert(key, func);
return func;
}
private:
std::unordered_map<int, typename KernelTuples::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncsCache);
DISABLE_COPY_AND_ASSIGN(KernelFuncs);
};
const char* to_string(KernelType kt);
......
......@@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) {
}
void Softmax(const T* x, T* y, int n, int bs) {
typename XRNTuples<T>::func_type compute_hmax{nullptr};
typename XRNTuples<T>::func_type compute_hsum{nullptr};
typename AXYNTuples<T>::func_type compute_vscal{nullptr};
typename AXYNTuples<T>::func_type compute_vaddbias{nullptr};
typename XYNTuples<T>::func_type compute_vexp{nullptr};
if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) {
compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax);
} else {
compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
}
if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
} else {
compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
}
if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
compute_vscal);
} else {
compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
}
if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
n, compute_vaddbias);
} else {
compute_vaddbias =
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
}
if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
} else {
compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
}
auto compute_hmax =
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_hsum =
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vscal =
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vaddbias =
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vexp =
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
for (int i = 0; i < bs; ++i) {
T scalar;
......
......@@ -82,8 +82,9 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
const int kClassDim = 1;
// 2D data. Batch x C
auto compute_softmax =
jit::Get<jit::kSoftmax, jit::SoftmaxTuples<float>, platform::CPUPlace>(
in_dims[kClassDim]);
jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册