提交 2513b2cc 编写于 作者: T tensor-tang

fix bug vtanh

上级 cf8c8e72
...@@ -29,7 +29,6 @@ namespace jitkernel { ...@@ -29,7 +29,6 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
#define AVX_FLOAT_BLOCK 8 #define AVX_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8 #define AVX2_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16 #define AVX512_FLOAT_BLOCK 16
...@@ -40,8 +39,9 @@ class Kernel { ...@@ -40,8 +39,9 @@ class Kernel {
public: public:
Kernel() = default; Kernel() = default;
virtual ~Kernel() = default; virtual ~Kernel() = default;
int num_{0};
private: int end_{0};
int rest_{0};
DISABLE_COPY_AND_ASSIGN(Kernel); DISABLE_COPY_AND_ASSIGN(Kernel);
}; };
...@@ -95,13 +95,13 @@ class VExpKernel : public Kernel { ...@@ -95,13 +95,13 @@ class VExpKernel : public Kernel {
template <typename T> template <typename T>
class VSigmoidKernel : public Kernel { class VSigmoidKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, T *y) const = 0; virtual void Compute(const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
class VTanhKernel : public Kernel { class VTanhKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, T *y) const = 0; virtual void Compute(const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
......
...@@ -113,17 +113,18 @@ template <typename T, jit::cpu_isa_t isa, jit_block> ...@@ -113,17 +113,18 @@ template <typename T, jit::cpu_isa_t isa, jit_block>
class VSigmoidKernelImpl : public VSigmoidKernel<T> { class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public: public:
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() { explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
this->num_ = d;
vexp_ = KernelPool::Instance().template Get<VExpKernel<T>>(d); vexp_ = KernelPool::Instance().template Get<VExpKernel<T>>(d);
} }
void Compute(const int n, const T* x, T* y) const override { void Compute(const T* x, T* y) const override {
const T min = SIGMOID_THRESHOLD_MIN; const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX; const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) { for (int i = 0; i < this->num_; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i]; y[i] = static_cast<T>(0) - y[i];
} }
vexp_->Compute(n, y, y); vexp_->Compute(this->num_, y, y);
for (int i = 0; i < n; ++i) { for (int i = 0; i < this->num_; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]); y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
} }
} }
...@@ -140,76 +141,89 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -140,76 +141,89 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp) tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VSigmoidKernelImpl<float, isa, kEQ8>::Compute( \ void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const int n, const float* x, float* y) const { \ const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \ __m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max); \ INTRI_SIGMOID(tmp, min, max); \
_mm256_storeu_ps(y, tmp); \ _mm256_storeu_ps(y, tmp); \
} }
#define INTRI16_FLOAT(isa) \ #define INTRI16_FLOAT(isa) \
template <> \ template <> \
void VSigmoidKernelImpl<float, isa, kEQ16>::Compute( \ void VSigmoidKernelImpl<float, isa, kEQ16>::Compute(const float* x, \
const int n, const float* x, float* y) const { \ float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_SIGMOID(tmp0, min, max); \ INTRI_SIGMOID(tmp0, min, max); \
INTRI_SIGMOID(tmp1, min, max); \ INTRI_SIGMOID(tmp1, min, max); \
_mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \ _mm256_storeu_ps(y + 8, tmp1); \
} }
#define INTRI_GT8LT16_FLOAT(isa) \ #define INTRI_GT8LT16_FLOAT(isa) \
template <> \ template <> \
void VSigmoidKernelImpl<float, isa, kGT8LT16>::Compute( \ VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
const int n, const float* x, float* y) const { \ : VSigmoidKernel<float>() { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ this->num_ = d; \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ this->end_ = AVX_FLOAT_BLOCK; \
__m256 tmp = _mm256_loadu_ps(x); \ this->rest_ = d - this->end_; \
INTRI_SIGMOID(tmp, min, max); \ vexp_ = KernelPool::Instance().template Get<VExpKernel<float>>(d); \
_mm256_storeu_ps(y, tmp); \ } \
const float min_ = SIGMOID_THRESHOLD_MIN; \ template <> \
const float max_ = SIGMOID_THRESHOLD_MAX; \ void VSigmoidKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
for (int i = AVX_FLOAT_BLOCK; i < n; ++i) { \ float* y) const { \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
y[i] = 0.f - y[i]; \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
} \ __m256 tmp = _mm256_loadu_ps(x); \
vexp_->Compute(n - AVX_FLOAT_BLOCK, y + AVX_FLOAT_BLOCK, \ INTRI_SIGMOID(tmp, min, max); \
y + AVX_FLOAT_BLOCK); \ _mm256_storeu_ps(y, tmp); \
for (int i = AVX_FLOAT_BLOCK; i < n; ++i) { \ const float min_ = SIGMOID_THRESHOLD_MIN; \
y[i] = 1.f / (1.f + y[i]); \ const float max_ = SIGMOID_THRESHOLD_MAX; \
} \ for (int i = this->end_; i < this->num_; ++i) { \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
y[i] = 0.f - y[i]; \
} \
vexp_->Compute(this->rest_, y + this->end_, y + this->end_); \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 1.f / (1.f + y[i]); \
} \
} }
#define INTRI_GT16_FLOAT(isa) \ #define INTRI_GT16_FLOAT(isa) \
template <> \ template <> \
void VSigmoidKernelImpl<float, isa, kGT16>::Compute( \ VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
const int n, const float* x, float* y) const { \ : VSigmoidKernel<float>() { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ this->num_ = d; \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ this->rest_ = d % AVX_FLOAT_BLOCK; \
const int rest = n % AVX_FLOAT_BLOCK; \ this->end_ = d - this->rest_; \
const int end = n - rest; \ vexp_ = KernelPool::Instance().template Get<VExpKernel<float>>(d); \
for (int i = 0; i < end; i += AVX_FLOAT_BLOCK) { \ } \
__m256 tmp = _mm256_loadu_ps(x + i); \ template <> \
INTRI_SIGMOID(tmp, min, max); \ void VSigmoidKernelImpl<float, isa, kGT16>::Compute(const float* x, \
_mm256_storeu_ps(y + i, tmp); \ float* y) const { \
} \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
const float min_ = SIGMOID_THRESHOLD_MIN; \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
const float max_ = SIGMOID_THRESHOLD_MAX; \ for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
for (int i = end; i < n; ++i) { \ __m256 tmp = _mm256_loadu_ps(x + i); \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ INTRI_SIGMOID(tmp, min, max); \
y[i] = 0.f - y[i]; \ _mm256_storeu_ps(y + i, tmp); \
} \ } \
vexp_->Compute(rest, y + end, y + end); \ const float min_ = SIGMOID_THRESHOLD_MIN; \
for (int i = end; i < n; ++i) { \ const float max_ = SIGMOID_THRESHOLD_MAX; \
y[i] = 1.f / (1.f + y[i]); \ for (int i = this->end_; i < this->num_; ++i) { \
} \ y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
y[i] = 0.f - y[i]; \
} \
vexp_->Compute(this->rest_, y + this->end_, y + this->end_); \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 1.f / (1.f + y[i]); \
} \
} }
#ifdef __AVX__ #ifdef __AVX__
...@@ -249,15 +263,16 @@ template <typename T, jit::cpu_isa_t isa, jit_block> ...@@ -249,15 +263,16 @@ template <typename T, jit::cpu_isa_t isa, jit_block>
class VTanhKernelImpl : public VTanhKernel<T> { class VTanhKernelImpl : public VTanhKernel<T> {
public: public:
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() { explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
this->num_ = d;
vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d); vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d); vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d); vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
} }
void Compute(const int n, const T* x, T* y) const override { void Compute(const T* x, T* y) const override {
vscal_->Compute(n, static_cast<T>(2), x, y); vscal_->Compute(this->num_, static_cast<T>(2), x, y);
vsigmoid_->Compute(n, y, y); vsigmoid_->Compute(y, y);
vscal_->Compute(n, static_cast<T>(2), y); vscal_->Compute(this->num_, static_cast<T>(2), y);
vaddbias_->Compute(n, static_cast<T>(-1), y, y); vaddbias_->Compute(this->num_, static_cast<T>(-1), y, y);
} }
private: private:
...@@ -274,60 +289,83 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -274,60 +289,83 @@ class VTanhKernelImpl : public VTanhKernel<T> {
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \ tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f)) tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VTanhKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \ void VTanhKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
float* y) const { \ const { \
__m256 tmp = _mm256_loadu_ps(x); \ __m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp); \ INTRI_VTANH(tmp); \
_mm256_storeu_ps(y, tmp); \ _mm256_storeu_ps(y, tmp); \
} }
#define INTRI16_FLOAT(isa) \ #define INTRI16_FLOAT(isa) \
template <> \ template <> \
void VTanhKernelImpl<float, isa, kEQ16>::Compute( \ void VTanhKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const int n, const float* x, float* y) const { \ const { \
__m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_VTANH(tmp0); \ INTRI_VTANH(tmp0); \
INTRI_VTANH(tmp1); \ INTRI_VTANH(tmp1); \
_mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \ _mm256_storeu_ps(y + 8, tmp1); \
} }
#define INTRI_GT8LT16_FLOAT(isa) \ #define INTRI_GT8LT16_FLOAT(isa) \
template <> \ template <> \
void VTanhKernelImpl<float, isa, kGT8LT16>::Compute( \ VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
const int n, const float* x, float* y) const { \ : VTanhKernel<float>() { \
__m256 tmp = _mm256_loadu_ps(x); \ this->num_ = d; \
INTRI_VTANH(tmp); \ this->end_ = AVX_FLOAT_BLOCK; \
_mm256_storeu_ps(y, tmp); \ this->rest_ = d - this->end_; \
x += AVX_FLOAT_BLOCK; \ vscal_ = \
y += AVX_FLOAT_BLOCK; \ KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
const int rest = n - AVX_FLOAT_BLOCK; \ vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
vscal_->Compute(rest, 2.f, x, y); \ this->rest_); \
vsigmoid_->Compute(rest, y, y); \ vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
vscal_->Compute(rest, 2.f, y); \ this->rest_); \
vaddbias_->Compute(rest, -1.f, y, y); \ } \
template <> \
void VTanhKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp); \
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
vscal_->Compute(this->rest_, 2.f, x, y); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(this->rest_, 2.f, y); \
vaddbias_->Compute(this->rest_, -1.f, y, y); \
} }
#define INTRI_GT16_FLOAT(isa) \ #define INTRI_GT16_FLOAT(isa) \
template <> \ template <> \
void VTanhKernelImpl<float, isa, kGT16>::Compute( \ VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
const int n, const float* x, float* y) const { \ : VTanhKernel<float>() { \
const int rest = n % AVX_FLOAT_BLOCK; \ this->num_ = d; \
const int end = n - rest; \ this->rest_ = d % AVX_FLOAT_BLOCK; \
for (int i = 0; i < end; i += AVX_FLOAT_BLOCK) { \ this->end_ = d - this->rest_; \
__m256 tmp = _mm256_loadu_ps(x + i); \ vscal_ = \
INTRI_VTANH(tmp); \ KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
_mm256_storeu_ps(y + i, tmp); \ vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
} \ this->rest_); \
x += end; \ vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
y += end; \ this->rest_); \
vscal_->Compute(rest, 2.f, x, y); \ } \
vsigmoid_->Compute(rest, y, y); \ template <> \
vscal_->Compute(rest, 2.f, y); \ void VTanhKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
vaddbias_->Compute(rest, -1.f, y, y); \ const { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_VTANH(tmp); \
_mm256_storeu_ps(y + i, tmp); \
} \
x += this->end_; \
y += this->end_; \
vscal_->Compute(this->rest_, 2.f, x, y); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(this->rest_, 2.f, y); \
vaddbias_->Compute(this->rest_, -1.f, y, y); \
} }
#ifdef __AVX__ #ifdef __AVX__
......
...@@ -195,7 +195,7 @@ TEST(JitKernel, vsigmoid) { ...@@ -195,7 +195,7 @@ TEST(JitKernel, vsigmoid) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, ztgt_data); ker->Compute(x_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -227,7 +227,7 @@ void vtanh_better( ...@@ -227,7 +227,7 @@ void vtanh_better(
vaddbias, vaddbias,
const int n, const float* x, float* y) { const int n, const float* x, float* y) {
vscal->Compute(n, 2.f, x, y); vscal->Compute(n, 2.f, x, y);
vsigmoid->Compute(n, y, y); vsigmoid->Compute(y, y);
vscal->Compute(n, 2.f, y); vscal->Compute(n, 2.f, y);
vaddbias->Compute(n, -1.f, y, y); vaddbias->Compute(n, -1.f, y, y);
} }
...@@ -261,7 +261,7 @@ TEST(JitKernel, vtanh) { ...@@ -261,7 +261,7 @@ TEST(JitKernel, vtanh) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, ztgt_data); ker->Compute(x_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册