Unverified commit a6a1a92e · authored by tensor-tang, committed by GitHub

Merge pull request #15586 from tensor-tang/jit/cache

refine bert
@@ -93,6 +93,7 @@ std::vector<int> TestSizes() {
 template <typename KernelTuples, typename... Args>
 struct BenchFunc {
   // return this function avg time
+  // TODO(TJ): clear cache every time
   double operator()(const typename KernelTuples::func_type tgt, Args... args) {
     for (int i = 0; i < FLAGS_burning; ++i) {
       tgt(args...);
@@ -172,6 +173,9 @@ void BenchXYZNKernel() {
     RandomVec<T>(d, y_data);
     BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
                                                      y.data<T>(), z_data, d);
+    // test inplace
+    BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
+                                                     z_data, d);
   }
 }
...
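For reference, a minimal standalone sketch of the burn-then-average timing pattern that BenchFunc uses. The constants kBurning and kRepeat are placeholders standing in for the benchmark's command-line flags (only FLAGS_burning is visible in the hunk), and the TODO added above notes that caches are not yet cleared between runs, so these timings are cache-warm.

#include <chrono>
#include <cstdio>

// Hypothetical stand-ins for the benchmark's command-line flags.
static const int kBurning = 10;
static const int kRepeat = 100;

// Run the callable kBurning times to warm up, then return the average
// wall-clock time (in ms) over kRepeat measured runs.
template <typename Func, typename... Args>
double AvgTimeMs(Func&& tgt, Args&&... args) {
  for (int i = 0; i < kBurning; ++i) {
    tgt(args...);
  }
  auto start = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < kRepeat; ++i) {
    tgt(args...);
  }
  auto end = std::chrono::high_resolution_clock::now();
  double total_ms = std::chrono::duration<double, std::milli>(end - start).count();
  return total_ms / kRepeat;
}

int main() {
  // Example: time a trivial vector-add loop, similar in spirit to the XYZN benchmarks.
  float x[1024], y[1024], z[1024];
  for (int i = 0; i < 1024; ++i) {
    x[i] = static_cast<float>(i);
    y[i] = 0.5f * static_cast<float>(i);
  }
  double ms = AvgTimeMs(
      [&](int n) {
        for (int i = 0; i < n; ++i) {
          z[i] = x[i] + y[i];
        }
      },
      1024);
  std::printf("avg time: %f ms\n", ms);
  return 0;
}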
@@ -155,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
   class name##Creator : public JitCodeCreator<int> { \
    public: \
     bool UseMe(const int& attr) const override { \
-      return platform::MayIUse(platform::avx); \
+      return platform::MayIUse(platform::avx) && attr <= 1024; \
     } \
     size_t CodeSize(const int& d) const override { \
       return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
...
@@ -61,6 +61,7 @@ class VXXJitCode : public JitCode {
       base += "_Vec";
     }
     base += (with_relu_ ? "_Relu" : "");
+    base += "_D" + std::to_string(num_);
     return base.c_str();
   }
   void genCode() override;
...
@@ -118,26 +118,33 @@ typename KernelTuples::func_type Get(
   return GetRefer<KT, KernelTuples>();
 }
-template <KernelType KT, typename KernelTuples>
-class KernelFuncsCache {
+template <KernelType KT, typename KernelTuples, typename PlaceType>
+class KernelFuncs {
  public:
-  KernelFuncsCache() = default;
-  static KernelFuncsCache& Instance() {
-    static thread_local KernelFuncsCache<KT, KernelTuples> g_func_cache;
+  KernelFuncs() = default;
+  static KernelFuncs& Cache() {
+    static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
     return g_func_cache;
   }
   bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
-  typename KernelTuples::func_type At(int key) { return funcs_.at(key); }
   void Insert(int key, typename KernelTuples::func_type func) {
     funcs_.emplace(key, func);
   }
+  typename KernelTuples::func_type At(int key) {
+    if (Has(key)) {
+      return funcs_.at(key);
+    }
+    auto func = Get<KT, KernelTuples, PlaceType>(key);
+    Insert(key, func);
+    return func;
+  }
  private:
   std::unordered_map<int, typename KernelTuples::func_type> funcs_;
-  DISABLE_COPY_AND_ASSIGN(KernelFuncsCache);
+  DISABLE_COPY_AND_ASSIGN(KernelFuncs);
 };
 const char* to_string(KernelType kt);
...
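A minimal standalone sketch of the caching pattern this hunk introduces: a per-thread singleton maps a problem size to an already-resolved kernel function, and At() now falls back to resolving and memoizing the function the first time a size is seen, so call sites no longer need the Has()/Insert() dance. The names below (FuncCache, Resolve, KernelFn) are illustrative placeholders, not the actual Paddle API.

#include <functional>
#include <unordered_map>

using KernelFn = std::function<void(const float*, float*, int)>;

// Placeholder for the expensive lookup/codegen step (jit::Get in the real code).
KernelFn Resolve(int /*n*/) {
  return [](const float* x, float* y, int len) {
    for (int i = 0; i < len; ++i) y[i] = x[i] * 2.0f;
  };
}

class FuncCache {
 public:
  static FuncCache& Cache() {
    static thread_local FuncCache cache;  // one cache per thread, no locking needed
    return cache;
  }
  KernelFn At(int key) {
    auto it = funcs_.find(key);
    if (it != funcs_.end()) {
      return it->second;  // hit: reuse the previously resolved function
    }
    KernelFn fn = Resolve(key);  // miss: resolve once, then memoize
    funcs_.emplace(key, fn);
    return fn;
  }

 private:
  FuncCache() = default;
  std::unordered_map<int, KernelFn> funcs_;
};

int main() {
  float x[8] = {1, 2, 3, 4, 5, 6, 7, 8}, y[8];
  auto fn = FuncCache::Cache().At(8);   // resolved and cached on first use
  fn(x, y, 8);
  auto fn2 = FuncCache::Cache().At(8);  // second lookup hits the cache
  fn2(x, y, 8);
  return 0;
}

The thread_local instance mirrors the original class and sidesteps synchronization, and keying on the size makes sense because the generated kernel depends on the problem size; the remaining hunks below are the call sites collapsing to a single Cache().At(n) expression.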
@@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) {
 }
 void Softmax(const T* x, T* y, int n, int bs) {
-  typename XRNTuples<T>::func_type compute_hmax{nullptr};
-  typename XRNTuples<T>::func_type compute_hsum{nullptr};
-  typename AXYNTuples<T>::func_type compute_vscal{nullptr};
-  typename AXYNTuples<T>::func_type compute_vaddbias{nullptr};
-  typename XYNTuples<T>::func_type compute_vexp{nullptr};
-
-  if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) {
-    compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
-    KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax);
-  } else {
-    compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
-  }
-
-  if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
-    compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
-    KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
-  } else {
-    compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
-  }
-
-  if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
-    compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
-    KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
-                                                               compute_vscal);
-  } else {
-    compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
-  }
-
-  if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
-    compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
-    KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
-        n, compute_vaddbias);
-  } else {
-    compute_vaddbias =
-        KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
-  }
-
-  if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
-    compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
-    KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
-  } else {
-    compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
-  }
+  auto compute_hmax =
+      KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
+  auto compute_hsum =
+      KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
+  auto compute_vscal =
+      KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
+  auto compute_vaddbias =
+      KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
+  auto compute_vexp =
+      KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
   for (int i = 0; i < bs; ++i) {
     T scalar;
...
@@ -136,7 +136,7 @@ bool VMulKernel<float>::UseMe(const int& d) const {
 template <>
 bool VAddKernel<float>::UseMe(const int& d) const {
-  return platform::MayIUse(platform::avx512f) && d > 512;
+  return platform::MayIUse(platform::avx) && d > 512;
 }
 template <>
...
@@ -30,15 +30,17 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
     return;
   }
   if (relu) {
-    auto compute =
-        jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
+    auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
+                                    platform::CPUPlace>::Cache()
+                       .At(N);
     for (int i = 0; i < M; i++) {
       T* dst = Y + i * N;
       compute(B, dst, dst, N);
     }
   } else {
-    auto compute =
-        jit::Get<jit::kVAdd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
+    auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>,
+                                    platform::CPUPlace>::Cache()
+                       .At(N);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
...
@@ -82,8 +82,9 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
     const int kClassDim = 1;
     // 2D data. Batch x C
     auto compute_softmax =
-        jit::Get<jit::kSoftmax, jit::SoftmaxTuples<float>, platform::CPUPlace>(
-            in_dims[kClassDim]);
+        jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
+                         platform::CPUPlace>::Cache()
+            .At(in_dims[kClassDim]);
     compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
   }
 };
...