提交 e6d8aca3 编写于 作者: T tensor-tang

refine code and fix

上级 ea7dc9cb
...@@ -64,32 +64,32 @@ class KernelPool { ...@@ -64,32 +64,32 @@ class KernelPool {
template <typename T> template <typename T>
class VMulKernel : public Kernel { class VMulKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0; virtual void Compute(const T *x, const T *y, T *z) const = 0;
}; };
template <typename T> template <typename T>
class VAddKernel : public Kernel { class VAddKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0; virtual void Compute(const T *x, const T *y, T *z) const = 0;
}; };
template <typename T> template <typename T>
class VScalKernel : public Kernel { class VScalKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T a, const T *x, T *y) const = 0; virtual void Compute(const T a, const T *x, T *y) const = 0;
virtual void Compute(const int n, const T a, T *x) const = 0; virtual void Compute(const T a, T *x) const = 0;
}; };
template <typename T> template <typename T>
class VAddBiasKernel : public Kernel { class VAddBiasKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T a, const T *x, T *y) const = 0; virtual void Compute(const T a, const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
class VExpKernel : public Kernel { class VExpKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, T *y) const = 0; virtual void Compute(const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
......
...@@ -34,41 +34,42 @@ namespace jit = platform::jit; ...@@ -34,41 +34,42 @@ namespace jit = platform::jit;
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VMulKernelImpl : public VMulKernel<T> { class VMulKernelImpl : public VMulKernel<T> {
public: public:
void Compute(const int n, const T* x, const T* y, T* z) const override { explicit VMulKernelImpl(int d) : VMulKernel<T>() { this->num_ = d; }
for (int i = 0; i < n; ++i) { void Compute(const T* x, const T* y, T* z) const override {
for (int i = 0; i < this->num_; ++i) {
z[i] = x[i] * y[i]; z[i] = x[i] * y[i];
} }
} }
}; };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \ #define MKL_FLOAT(isa, block) \
template <> \ template <> \
void VMulKernelImpl<float, isa, block>::Compute( \ void VMulKernelImpl<float, isa, block>::Compute( \
const int n, const float* x, const float* y, float* z) const { \ const float* x, const float* y, float* z) const { \
platform::dynload::vsMul(n, x, y, z); \ platform::dynload::vsMul(this->num_, x, y, z); \
} }
#define MKL_DOUBLE(isa, block) \ #define MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VMulKernelImpl<double, isa, block>::Compute( \ void VMulKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) const { \ const double* x, const double* y, double* z) const { \
platform::dynload::vdMul(n, x, y, z); \ platform::dynload::vdMul(this->num_, x, y, z); \
} }
FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif #endif
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute( \ void VMulKernelImpl<float, isa, kEQ8>::Compute( \
const int n, const float* x, const float* y, float* z) const { \ const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \ __m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \ tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \ tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \ tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \ _mm256_storeu_ps(z, tmpx); \
} }
// avx > for > mkl // avx > for > mkl
...@@ -90,41 +91,42 @@ INTRI8_FLOAT(jit::avx512f); ...@@ -90,41 +91,42 @@ INTRI8_FLOAT(jit::avx512f);
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddKernelImpl : public VAddKernel<T> { class VAddKernelImpl : public VAddKernel<T> {
public: public:
void Compute(const int n, const T* x, const T* y, T* z) const override { explicit VAddKernelImpl(int d) : VAddKernel<T>() { this->num_ = d; }
for (int i = 0; i < n; ++i) { void Compute(const T* x, const T* y, T* z) const override {
for (int i = 0; i < this->num_; ++i) {
z[i] = x[i] + y[i]; z[i] = x[i] + y[i];
} }
} }
}; };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \ #define MKL_FLOAT(isa, block) \
template <> \ template <> \
void VAddKernelImpl<float, isa, block>::Compute( \ void VAddKernelImpl<float, isa, block>::Compute( \
const int n, const float* x, const float* y, float* z) const { \ const float* x, const float* y, float* z) const { \
platform::dynload::vsAdd(n, x, y, z); \ platform::dynload::vsAdd(this->num_, x, y, z); \
} }
#define MKL_DOUBLE(isa, block) \ #define MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VAddKernelImpl<double, isa, block>::Compute( \ void VAddKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) const { \ const double* x, const double* y, double* z) const { \
platform::dynload::vdAdd(n, x, y, z); \ platform::dynload::vdAdd(this->num_, x, y, z); \
} }
FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif #endif
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute( \ void VAddKernelImpl<float, isa, kEQ8>::Compute( \
const int n, const float* x, const float* y, float* z) const { \ const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \ __m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \ tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \ tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \ tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \ _mm256_storeu_ps(z, tmpx); \
} }
#ifdef __AVX__ #ifdef __AVX__
INTRI8_FLOAT(jit::avx); INTRI8_FLOAT(jit::avx);
...@@ -145,56 +147,57 @@ INTRI8_FLOAT(jit::avx512f); ...@@ -145,56 +147,57 @@ INTRI8_FLOAT(jit::avx512f);
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VScalKernelImpl : public VScalKernel<T> { class VScalKernelImpl : public VScalKernel<T> {
public: public:
void Compute(const int n, const T a, const T* x, T* y) const override { explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
for (int i = 0; i < n; ++i) { void Compute(const T a, const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = a * x[i]; y[i] = a * x[i];
} }
} }
void Compute(const int n, const T a, T* x) const override { void Compute(const T a, T* x) const override {
for (int i = 0; i < n; ++i) { for (int i = 0; i < this->num_; ++i) {
x[i] = a * x[i]; x[i] = a * x[i];
} }
} }
}; };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \ #define MKL_FLOAT(isa, block) \
template <> \ template <> \
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \ void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
float* x) const { \ const { \
platform::dynload::cblas_sscal(n, a, x, 1); \ platform::dynload::cblas_sscal(this->num_, a, x, 1); \
} }
#define MKL_DOUBLE(isa, block) \ #define MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VScalKernelImpl<double, isa, block>::Compute( \ void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
const int n, const double a, double* x) const { \ const { \
platform::dynload::cblas_dscal(n, a, x, 1); \ platform::dynload::cblas_dscal(this->num_, a, x, 1); \
} }
FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif #endif
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute( \ void VScalKernelImpl<float, isa, kEQ8>::Compute( \
const int n, const float a, const float* x, float* y) const { \ const float a, const float* x, float* y) const { \
__m256 tmp; \ __m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \ __m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \ tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \ tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \ _mm256_storeu_ps(y, tmp); \
} }
#define INTRI8_INPLACE_FLOAT(isa) \ #define INTRI8_INPLACE_FLOAT(isa) \
template <> \ template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \ void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
float* x) const { \ const { \
__m256 tmp; \ __m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \ __m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \ tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \ tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(x, tmp); \ _mm256_storeu_ps(x, tmp); \
} }
#ifdef __AVX__ #ifdef __AVX__
...@@ -220,32 +223,33 @@ INTRI8_INPLACE_FLOAT(jit::avx512f); ...@@ -220,32 +223,33 @@ INTRI8_INPLACE_FLOAT(jit::avx512f);
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddBiasKernelImpl : public VAddBiasKernel<T> { class VAddBiasKernelImpl : public VAddBiasKernel<T> {
public: public:
void Compute(const int n, const T a, const T* x, T* y) const override { explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
for (int i = 0; i < n; ++i) { void Compute(const T a, const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = x[i] + a; y[i] = x[i] + a;
} }
} }
}; };
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VAddBiasKernelImpl<float, isa, kEQ8>::Compute( \ void VAddBiasKernelImpl<float, isa, kEQ8>::Compute( \
const int n, const float a, const float* x, float* y) const { \ const float a, const float* x, float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \ __m256 tmp = _mm256_loadu_ps(x); \
tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \ tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \
_mm256_storeu_ps(y, tmp); \ _mm256_storeu_ps(y, tmp); \
} }
#define INTRI16_FLOAT(isa) \ #define INTRI16_FLOAT(isa) \
template <> \ template <> \
void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \ void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
const int n, const float a, const float* x, float* y) const { \ const float a, const float* x, float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \ tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \
tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \ tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \
_mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \ _mm256_storeu_ps(y + 8, tmp1); \
} }
#ifdef __AVX__ #ifdef __AVX__
......
...@@ -40,26 +40,27 @@ namespace jit = platform::jit; ...@@ -40,26 +40,27 @@ namespace jit = platform::jit;
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
class VExpKernelImpl : public VExpKernel<T> { class VExpKernelImpl : public VExpKernel<T> {
public: public:
void Compute(const int n, const T* x, T* y) const override { explicit VExpKernelImpl(int d) : VExpKernel<T>() { this->num_ = d; }
for (int i = 0; i < n; ++i) { void Compute(const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = std::exp(x[i]); y[i] = std::exp(x[i]);
} }
} }
}; };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \ #define MKL_FLOAT(isa, block) \
template <> \ template <> \
void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \ void VExpKernelImpl<float, isa, block>::Compute(const float* x, float* y) \
float* y) const { \ const { \
platform::dynload::vsExp(n, x, y); \ platform::dynload::vsExp(this->num_, x, y); \
} }
#define MKL_DOUBLE(isa, block) \ #define MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VExpKernelImpl<double, isa, block>::Compute( \ void VExpKernelImpl<double, isa, block>::Compute(const double* x, double* y) \
const int n, const double* x, double* y) const { \ const { \
platform::dynload::vdExp(n, x, y); \ platform::dynload::vdExp(this->num_, x, y); \
} }
FOR_EACH_ISA(MKL_FLOAT, kLT8); FOR_EACH_ISA(MKL_FLOAT, kLT8);
FOR_EACH_ISA(MKL_FLOAT, kGT8LT16); FOR_EACH_ISA(MKL_FLOAT, kGT8LT16);
...@@ -67,24 +68,24 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16); ...@@ -67,24 +68,24 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif #endif
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \ void VExpKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
float* y) const { \ const { \
__m256 tmp = _mm256_loadu_ps(x); \ __m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, detail::Exp(tmp)); \ _mm256_storeu_ps(y, detail::Exp(tmp)); \
} }
#define INTRI16_FLOAT(isa) \ #define INTRI16_FLOAT(isa) \
template <> \ template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \ void VExpKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
float* y) const { \ const { \
__m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = detail::Exp(tmp0); \ tmp0 = detail::Exp(tmp0); \
tmp1 = detail::Exp(tmp1); \ tmp1 = detail::Exp(tmp1); \
_mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \ _mm256_storeu_ps(y + 8, tmp1); \
} }
#ifdef __AVX__ #ifdef __AVX__
...@@ -123,7 +124,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -123,7 +124,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i]; y[i] = static_cast<T>(0) - y[i];
} }
vexp_->Compute(this->num_, y, y); vexp_->Compute(y, y);
for (int i = 0; i < this->num_; ++i) { for (int i = 0; i < this->num_; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]); y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
} }
...@@ -166,64 +167,66 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -166,64 +167,66 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
_mm256_storeu_ps(y + 8, tmp1); \ _mm256_storeu_ps(y + 8, tmp1); \
} }
#define INTRI_GT8LT16_FLOAT(isa) \ #define INTRI_GT8LT16_FLOAT(isa) \
template <> \ template <> \
VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \ VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \ : VSigmoidKernel<float>() { \
this->num_ = d; \ this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \ this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \ this->rest_ = d - this->end_; \
vexp_ = KernelPool::Instance().template Get<VExpKernel<float>>(d); \ vexp_ = \
} \ KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
template <> \ } \
void VSigmoidKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \ template <> \
float* y) const { \ void VSigmoidKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ float* y) const { \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 tmp = _mm256_loadu_ps(x); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
INTRI_SIGMOID(tmp, min, max); \ __m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, tmp); \ INTRI_SIGMOID(tmp, min, max); \
const float min_ = SIGMOID_THRESHOLD_MIN; \ _mm256_storeu_ps(y, tmp); \
const float max_ = SIGMOID_THRESHOLD_MAX; \ const float min_ = SIGMOID_THRESHOLD_MIN; \
for (int i = this->end_; i < this->num_; ++i) { \ const float max_ = SIGMOID_THRESHOLD_MAX; \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 0.f - y[i]; \ y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
} \ y[i] = 0.f - y[i]; \
vexp_->Compute(this->rest_, y + this->end_, y + this->end_); \ } \
for (int i = this->end_; i < this->num_; ++i) { \ vexp_->Compute(y + this->end_, y + this->end_); \
y[i] = 1.f / (1.f + y[i]); \ for (int i = this->end_; i < this->num_; ++i) { \
} \ y[i] = 1.f / (1.f + y[i]); \
} \
} }
#define INTRI_GT16_FLOAT(isa) \ #define INTRI_GT16_FLOAT(isa) \
template <> \ template <> \
VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \ VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \ : VSigmoidKernel<float>() { \
this->num_ = d; \ this->num_ = d; \
this->rest_ = d % AVX_FLOAT_BLOCK; \ this->rest_ = d % AVX_FLOAT_BLOCK; \
this->end_ = d - this->rest_; \ this->end_ = d - this->rest_; \
vexp_ = KernelPool::Instance().template Get<VExpKernel<float>>(d); \ vexp_ = \
} \ KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
template <> \ } \
void VSigmoidKernelImpl<float, isa, kGT16>::Compute(const float* x, \ template <> \
float* y) const { \ void VSigmoidKernelImpl<float, isa, kGT16>::Compute(const float* x, \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ float* y) const { \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x + i); \ for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
INTRI_SIGMOID(tmp, min, max); \ __m256 tmp = _mm256_loadu_ps(x + i); \
_mm256_storeu_ps(y + i, tmp); \ INTRI_SIGMOID(tmp, min, max); \
} \ _mm256_storeu_ps(y + i, tmp); \
const float min_ = SIGMOID_THRESHOLD_MIN; \ } \
const float max_ = SIGMOID_THRESHOLD_MAX; \ const float min_ = SIGMOID_THRESHOLD_MIN; \
for (int i = this->end_; i < this->num_; ++i) { \ const float max_ = SIGMOID_THRESHOLD_MAX; \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 0.f - y[i]; \ y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
} \ y[i] = 0.f - y[i]; \
vexp_->Compute(this->rest_, y + this->end_, y + this->end_); \ } \
for (int i = this->end_; i < this->num_; ++i) { \ vexp_->Compute(y + this->end_, y + this->end_); \
y[i] = 1.f / (1.f + y[i]); \ for (int i = this->end_; i < this->num_; ++i) { \
} \ y[i] = 1.f / (1.f + y[i]); \
} \
} }
#ifdef __AVX__ #ifdef __AVX__
...@@ -251,12 +254,7 @@ INTRI16_FLOAT(jit::avx512f); ...@@ -251,12 +254,7 @@ INTRI16_FLOAT(jit::avx512f);
#undef INTRI_GT16_FLOAT #undef INTRI_GT16_FLOAT
#undef INTRI_VSIGMOID #undef INTRI_VSIGMOID
#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \ REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE,
JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL);
/* VTanh JitKernel */ /* VTanh JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
...@@ -269,10 +267,10 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -269,10 +267,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d); vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
} }
void Compute(const T* x, T* y) const override { void Compute(const T* x, T* y) const override {
vscal_->Compute(this->num_, static_cast<T>(2), x, y); vscal_->Compute(static_cast<T>(2), x, y);
vsigmoid_->Compute(y, y); vsigmoid_->Compute(y, y);
vscal_->Compute(this->num_, static_cast<T>(2), y); vscal_->Compute(static_cast<T>(2), y);
vaddbias_->Compute(this->num_, static_cast<T>(-1), y, y); vaddbias_->Compute(static_cast<T>(-1), y, y);
} }
private: private:
...@@ -332,10 +330,10 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -332,10 +330,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
_mm256_storeu_ps(y, tmp); \ _mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \ x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \ y += AVX_FLOAT_BLOCK; \
vscal_->Compute(this->rest_, 2.f, x, y); \ vscal_->Compute(2.f, x, y); \
vsigmoid_->Compute(y, y); \ vsigmoid_->Compute(y, y); \
vscal_->Compute(this->rest_, 2.f, y); \ vscal_->Compute(2.f, y); \
vaddbias_->Compute(this->rest_, -1.f, y, y); \ vaddbias_->Compute(-1.f, y, y); \
} }
#define INTRI_GT16_FLOAT(isa) \ #define INTRI_GT16_FLOAT(isa) \
...@@ -362,10 +360,10 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -362,10 +360,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
} \ } \
x += this->end_; \ x += this->end_; \
y += this->end_; \ y += this->end_; \
vscal_->Compute(this->rest_, 2.f, x, y); \ vscal_->Compute(2.f, x, y); \
vsigmoid_->Compute(y, y); \ vsigmoid_->Compute(y, y); \
vscal_->Compute(this->rest_, 2.f, y); \ vscal_->Compute(2.f, y); \
vaddbias_->Compute(this->rest_, -1.f, y, y); \ vaddbias_->Compute(-1.f, y, y); \
} }
#ifdef __AVX__ #ifdef __AVX__
...@@ -391,8 +389,7 @@ INTRI16_FLOAT(jit::avx512f); ...@@ -391,8 +389,7 @@ INTRI16_FLOAT(jit::avx512f);
#undef INTRI_GT16_FLOAT #undef INTRI_GT16_FLOAT
#undef INTRI_VTANH #undef INTRI_VTANH
REGISTER_JITKERNEL_ARGS(vtanh, VTanhKernel, JITKERNEL_DECLARE, JITKERNEL_KEY, REGISTER_JITKERNEL(vtanh, VTanhKernel);
JITKERNEL_NEW_ACT_IMPL);
#undef JITKERNEL_NEW_ACT_IMPL #undef JITKERNEL_NEW_ACT_IMPL
......
...@@ -57,7 +57,7 @@ namespace jit = platform::jit; ...@@ -57,7 +57,7 @@ namespace jit = platform::jit;
#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \ #define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \ p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>()) std::make_shared<ker##Impl<dtype, isa, k>>(d))
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \ #define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
marco_declare, macro_key, macro_impl) \ marco_declare, macro_key, macro_impl) \
......
...@@ -73,7 +73,7 @@ TEST(JitKernel, vaddbias) { ...@@ -73,7 +73,7 @@ TEST(JitKernel, vaddbias) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, a, x_data, ztgt_data); ker->Compute(a, x_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -99,7 +99,7 @@ void vexp_mkl(const int n, const float* x, float* y) { ...@@ -99,7 +99,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
TEST(JitKernel, vexp) { TEST(JitKernel, vexp) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 128}) { for (int d : {7, 8, 15, 16, 30, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f); RandomVec<float>(d, x.data(), -2.f, 2.f);
...@@ -124,7 +124,7 @@ TEST(JitKernel, vexp) { ...@@ -124,7 +124,7 @@ TEST(JitKernel, vexp) {
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, ztgt_data); ker->Compute(x_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -164,7 +164,7 @@ void vsigmoid_better( ...@@ -164,7 +164,7 @@ void vsigmoid_better(
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 0.f - y[i]; y[i] = 0.f - y[i];
} }
vexp->Compute(n, y, y); vexp->Compute(y, y);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = 1.f / (1.f + y[i]); y[i] = 1.f / (1.f + y[i]);
} }
...@@ -226,10 +226,10 @@ void vtanh_better( ...@@ -226,10 +226,10 @@ void vtanh_better(
const paddle::operators::math::jitkernel::VAddBiasKernel<float>>& const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
vaddbias, vaddbias,
const int n, const float* x, float* y) { const int n, const float* x, float* y) {
vscal->Compute(n, 2.f, x, y); vscal->Compute(2.f, x, y);
vsigmoid->Compute(y, y); vsigmoid->Compute(y, y);
vscal->Compute(n, 2.f, y); vscal->Compute(2.f, y);
vaddbias->Compute(n, -1.f, y, y); vaddbias->Compute(-1.f, y, y);
} }
TEST(JitKernel, vtanh) { TEST(JitKernel, vtanh) {
...@@ -359,12 +359,12 @@ TEST(JitKernel, vscal) { ...@@ -359,12 +359,12 @@ TEST(JitKernel, vscal) {
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, a, x_data, ztgt_data); ker->Compute(a, x_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
auto ttgts1 = GetCurrentUS(); auto ttgts1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, a, y_data); ker->Compute(a, y_data);
} }
auto ttgte1 = GetCurrentUS(); auto ttgte1 = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
...@@ -444,7 +444,7 @@ TEST(JitKernel, vmul) { ...@@ -444,7 +444,7 @@ TEST(JitKernel, vmul) {
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, y_data, ztgt_data); ker->Compute(x_data, y_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -523,7 +523,7 @@ TEST(JitKernel, vadd) { ...@@ -523,7 +523,7 @@ TEST(JitKernel, vadd) {
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, y_data, ztgt_data); ker->Compute(x_data, y_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册