From 3e01a4048f28ad5cf4b33fb808b07965d9e7ff5d Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 28 Dec 2018 16:34:13 +0000 Subject: [PATCH] add refer seqpool jitkernel --- paddle/fluid/operators/jit/kernel_base.h | 20 +++++++++++++++++++ paddle/fluid/operators/jit/kernel_key.cc | 6 ++++++ .../fluid/operators/jit/refer/CMakeLists.txt | 1 + paddle/fluid/operators/jit/refer/refer.cc | 2 ++ paddle/fluid/operators/jit/refer/refer.h | 16 +++++++++++++++ 5 files changed, 45 insertions(+) diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index b4a2d5d47..8f13fbb16 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -41,6 +41,7 @@ typedef enum { kCRFDecoding, kLayerNorm, kNCHW16CMulNC, + kSeqPool, } KernelType; template @@ -112,6 +113,25 @@ struct GRUTuples { typedef void (*func_type)(gru_t*, const gru_attr_t*); }; +typedef enum { + non = 0, + sum, + avg, + sqrt, +} SeqPoolType; + +typedef struct { + int h, w; + SeqPoolType type; +} seq_pool_attr_t; + +template +struct SeqPoolTuples { + typedef T data_type; + typedef seq_pool_attr_t attr_type; + typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); +}; + template struct CRFDecodingTuples { typedef T data_type; diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index 4e6a19f04..6b0025a75 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -42,6 +42,12 @@ size_t JitCodeKey(const gru_attr_t& attr) { (static_cast(attr.act_cand) << act_type_shift); } +template <> +size_t JitCodeKey(const seq_pool_attr_t& attr) { + size_t key = static_cast(attr.type); + return key + (attr.w << act_type_shift); +} + } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 07497b732..0f626bb3b 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2) USE_JITKERNEL_REFER(kCRFDecoding) USE_JITKERNEL_REFER(kLayerNorm) USE_JITKERNEL_REFER(kNCHW16CMulNC) +USE_JITKERNEL_REFER(kSeqPool) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index d19626632..85381daa4 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); +REGISTER_REFER_KERNEL(kSeqPool, SeqPool); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 0fd1b89df..52fe2de02 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -332,6 +332,20 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { } } +template +void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { + PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet"); + for (int w = 0; w < attr->w; ++w) { + const T* src = x + w; + T* dst = y + w; + *dst = static_cast(0); + for (int h = 0; h < attr->h; ++h) { + *dst = *dst + *src; + src += attr->w; + } + } +} + #define DECLARE_REFER_KERNEL(name, tuples) \ template \ class name##Kernel : public ReferKernel> { \ @@ -370,6 +384,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); +DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer -- GitLab