提交 92201d39 编写于 作者: T tensor-tang

support avg and sqrt pool and add mkl impl

test=develop
上级 c50060bb
...@@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl) ...@@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl) USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
...@@ -72,6 +72,26 @@ void VExp<double>(const double* x, double* y, int n) { ...@@ -72,6 +72,26 @@ void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y); platform::dynload::vdExp(n, x, y);
} }
template <>
void VCopy<float>(const float* x, float* y, int n) {
platform::dynload::cblas_scopy(n, x, 1, y, 1);
}
template <>
void VCopy<double>(const double* x, double* y, int n) {
platform::dynload::cblas_dcopy(n, x, 1, y, 1);
}
template <>
void VAXPY<float>(float a, const float* x, float* y, int n) {
platform::dynload::cblas_saxpy(n, a, x, 1, y, 1);
}
template <>
void VAXPY<double>(double a, const double* x, double* y, int n) {
platform::dynload::cblas_daxpy(n, a, x, 1, y, 1);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <> template <>
bool VMulKernel<float>::UseMe(const int& d) const { bool VMulKernel<float>::UseMe(const int& d) const {
...@@ -103,6 +123,16 @@ bool VTanhKernel<float>::UseMe(const int& d) const { ...@@ -103,6 +123,16 @@ bool VTanhKernel<float>::UseMe(const int& d) const {
return d > 7; return d > 7;
} }
template <>
bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \ #define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \ template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \ bool func##Kernel<double>::UseMe(const int& d) const { \
...@@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal); ...@@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
#undef REGISTER_MKL_KERNEL #undef REGISTER_MKL_KERNEL
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <cmath>
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
...@@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n); ...@@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n);
template <typename T> template <typename T>
void VExp(const T* x, T* y, int n); void VExp(const T* x, T* y, int n);
template <typename T>
void VCopy(const T* x, T* y, int n);
template <typename T>
void VAXPY(T a, const T* x, T* y, int n);
template <typename T> template <typename T>
void VSigmoid(const T* x, T* y, int n) { void VSigmoid(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN; const T min = SIGMOID_THRESHOLD_MIN;
...@@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) { ...@@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) {
} }
} }
template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
VCopy<T>(x, y, attr->w);
for (int h = 1; h != attr->h; ++h) {
VAXPY<T>(static_cast<T>(1), x + h * attr->w, y, attr->w);
}
if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) {
T scalar = static_cast<T>(1);
if (attr->type == SeqPoolType::kAvg) {
scalar = scalar / static_cast<T>(attr->h);
} else {
scalar = scalar / std::sqrt(static_cast<T>(attr->h));
}
VScal<T>(&scalar, y, y, attr->w);
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \ #define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \ template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \ class name##Kernel : public KernelMore<tuples<T>> { \
...@@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples); ...@@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples); DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
#undef DECLARE_MKL_KERNEL #undef DECLARE_MKL_KERNEL
} // namespace mkl } // namespace mkl
......
...@@ -344,6 +344,15 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { ...@@ -344,6 +344,15 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
src += attr->w; src += attr->w;
} }
} }
if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) {
T scalar = static_cast<T>(1);
if (attr->type == SeqPoolType::kAvg) {
scalar = scalar / static_cast<T>(attr->h);
} else {
scalar = scalar / std::sqrt(static_cast<T>(attr->h));
}
VScal<T>(&scalar, y, y, attr->w);
}
} }
#define DECLARE_REFER_KERNEL(name, tuples) \ #define DECLARE_REFER_KERNEL(name, tuples) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册