diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index 863cc720d68ce3dcfe045aa11c559a06a50909f3..f5ed2f0572176e42b774259c2b8fe9713d989417 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl) USE_JITKERNEL_MORE(kVExp, mkl) USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVTanh, mkl) +USE_JITKERNEL_MORE(kSeqPool, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index a5b088d4812b8a54e3b4fb1cb83d9e8bc7501994..5a499ac2c02aa70d2824f0d3be618e083ba10334 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -72,6 +72,26 @@ void VExp(const double* x, double* y, int n) { platform::dynload::vdExp(n, x, y); } +template <> +void VCopy(const float* x, float* y, int n) { + platform::dynload::cblas_scopy(n, x, 1, y, 1); +} + +template <> +void VCopy(const double* x, double* y, int n) { + platform::dynload::cblas_dcopy(n, x, 1, y, 1); +} + +template <> +void VAXPY(float a, const float* x, float* y, int n) { + platform::dynload::cblas_saxpy(n, a, x, 1, y, 1); +} + +template <> +void VAXPY(double a, const double* x, double* y, int n) { + platform::dynload::cblas_daxpy(n, a, x, 1, y, 1); +} + // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 template <> bool VMulKernel::UseMe(const int& d) const { @@ -103,6 +123,16 @@ bool VTanhKernel::UseMe(const int& d) const { return d > 7; } +template <> +bool SeqPoolKernel::UseMe(const seq_pool_attr_t& attr) const { + return true; +} + +template <> +bool SeqPoolKernel::UseMe(const seq_pool_attr_t& attr) const { + return true; +} + #define AWALYS_USE_ME_WITH_DOUBLE(func) \ template <> \ bool func##Kernel::UseMe(const int& d) const { \ @@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVTanh, VTanh); +REGISTER_MKL_KERNEL(kSeqPool, SeqPool); #undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index ee1031c028ff72181f504004b7cbeb9f7ee578f1..0a3816db24ccd0820cb259b40044e1f5b66665f7 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/fluid/operators/jit/kernel_base.h" @@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n); template void VExp(const T* x, T* y, int n); +template +void VCopy(const T* x, T* y, int n); + +template +void VAXPY(T a, const T* x, T* y, int n); + template void VSigmoid(const T* x, T* y, int n) { const T min = SIGMOID_THRESHOLD_MIN; @@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) { } } +template +void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { + VCopy(x, y, attr->w); + for (int h = 1; h != attr->h; ++h) { + VAXPY(static_cast(1), x + h * attr->w, y, attr->w); + } + if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) { + T scalar = static_cast(1); + if (attr->type == SeqPoolType::kAvg) { + scalar = scalar / static_cast(attr->h); + } else { + scalar = scalar / std::sqrt(static_cast(attr->h)); + } + VScal(&scalar, y, y, attr->w); + } +} + #define DECLARE_MKL_KERNEL(name, tuples) \ template \ class name##Kernel : public KernelMore> { \ @@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VTanh, XYNTuples); +DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); + #undef DECLARE_MKL_KERNEL } // namespace mkl diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index c2aa9225284a215e9c4ad1f27fffdee9857bf78f..4e19783c862f18c344a2807fa06938a603bfc7f1 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -344,6 +344,15 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { src += attr->w; } } + if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) { + T scalar = static_cast(1); + if (attr->type == SeqPoolType::kAvg) { + scalar = scalar / static_cast(attr->h); + } else { + scalar = scalar / std::sqrt(static_cast(attr->h)); + } + VScal(&scalar, y, y, attr->w); + } } #define DECLARE_REFER_KERNEL(name, tuples) \