diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index b4a2d5d47301a2fd82bf27ddfaaa31ef23e431c2..8f13fbb16e9200fc082d3aa8901e494682a54278 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -41,6 +41,7 @@ typedef enum { kCRFDecoding, kLayerNorm, kNCHW16CMulNC, + kSeqPool, } KernelType; template @@ -112,6 +113,25 @@ struct GRUTuples { typedef void (*func_type)(gru_t*, const gru_attr_t*); }; +typedef enum { + non = 0, + sum, + avg, + sqrt, +} SeqPoolType; + +typedef struct { + int h, w; + SeqPoolType type; +} seq_pool_attr_t; + +template +struct SeqPoolTuples { + typedef T data_type; + typedef seq_pool_attr_t attr_type; + typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); +}; + template struct CRFDecodingTuples { typedef T data_type; diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index 4e6a19f04fd425b920aeea49b63001941d800a73..6b0025a75a4a6bba53e695c344e57e9e325a64fa 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -42,6 +42,12 @@ size_t JitCodeKey(const gru_attr_t& attr) { (static_cast(attr.act_cand) << act_type_shift); } +template <> +size_t JitCodeKey(const seq_pool_attr_t& attr) { + size_t key = static_cast(attr.type); + return key + (attr.w << act_type_shift); +} + } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 07497b732050a7299e224531db37eb56e60ef605..0f626bb3bfd2851e3fb6ad8265169f9bb9860851 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2) USE_JITKERNEL_REFER(kCRFDecoding) USE_JITKERNEL_REFER(kLayerNorm) USE_JITKERNEL_REFER(kNCHW16CMulNC) +USE_JITKERNEL_REFER(kSeqPool) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index d196266326b4ee668f647fa51032f6344d26e5c6..85381daa47484a4053326f04e12d583543a423e0 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); +REGISTER_REFER_KERNEL(kSeqPool, SeqPool); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 0fd1b89dfdba9f4655f649fa6d32604188c78da3..52fe2de02a8731ee93ec684f883276dc8edceb72 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -332,6 +332,20 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { } } +template +void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { + PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet"); + for (int w = 0; w < attr->w; ++w) { + const T* src = x + w; + T* dst = y + w; + *dst = static_cast(0); + for (int h = 0; h < attr->h; ++h) { + *dst = *dst + *src; + src += attr->w; + } + } +} + #define DECLARE_REFER_KERNEL(name, tuples) \ template \ class name##Kernel : public ReferKernel> { \ @@ -370,6 +384,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); +DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer