Commit e6d8aca3 authored by tensor-tang

refine code and fix

Parent: ea7dc9cb
@@ -64,32 +64,32 @@ class KernelPool {
 template <typename T>
 class VMulKernel : public Kernel {
  public:
-  virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0;
+  virtual void Compute(const T *x, const T *y, T *z) const = 0;
 };
 template <typename T>
 class VAddKernel : public Kernel {
  public:
-  virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0;
+  virtual void Compute(const T *x, const T *y, T *z) const = 0;
 };
 template <typename T>
 class VScalKernel : public Kernel {
  public:
-  virtual void Compute(const int n, const T a, const T *x, T *y) const = 0;
-  virtual void Compute(const int n, const T a, T *x) const = 0;
+  virtual void Compute(const T a, const T *x, T *y) const = 0;
+  virtual void Compute(const T a, T *x) const = 0;
 };
 template <typename T>
 class VAddBiasKernel : public Kernel {
  public:
-  virtual void Compute(const int n, const T a, const T *x, T *y) const = 0;
+  virtual void Compute(const T a, const T *x, T *y) const = 0;
 };
 template <typename T>
 class VExpKernel : public Kernel {
  public:
-  virtual void Compute(const int n, const T *x, T *y) const = 0;
+  virtual void Compute(const T *x, T *y) const = 0;
 };
 template <typename T>
......
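The interface change above is the heart of the commit: the vector length moves out of every `Compute` call and into kernel state (`num_`), bound once at construction. A minimal self-contained sketch of the resulting call pattern; names mirror the diff, while the real `Kernel` base and `KernelPool` plumbing are reduced to essentials:

```cpp
#include <cstdio>
#include <vector>

// Sketch only: the element count is fixed at construction time (stored in
// num_) instead of being passed to every Compute() call.
class Kernel {
 public:
  virtual ~Kernel() = default;

 protected:
  int num_ = 0;  // vector length, set by the concrete kernel's constructor
};

template <typename T>
class VMulKernel : public Kernel {
 public:
  virtual void Compute(const T* x, const T* y, T* z) const = 0;
};

template <typename T>
class VMulKernelImpl : public VMulKernel<T> {
 public:
  explicit VMulKernelImpl(int d) { this->num_ = d; }
  void Compute(const T* x, const T* y, T* z) const override {
    for (int i = 0; i < this->num_; ++i) z[i] = x[i] * y[i];
  }
};

int main() {
  const int d = 8;
  std::vector<float> x(d, 2.f), y(d, 3.f), z(d);
  VMulKernelImpl<float> ker(d);               // size bound once here...
  ker.Compute(x.data(), y.data(), z.data());  // ...so no length argument
  std::printf("%g\n", z[0]);                  // prints 6
}
```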
@@ -34,8 +34,9 @@ namespace jit = platform::jit;
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VMulKernelImpl : public VMulKernel<T> {
  public:
-  void Compute(const int n, const T* x, const T* y, T* z) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VMulKernelImpl(int d) : VMulKernel<T>() { this->num_ = d; }
+  void Compute(const T* x, const T* y, T* z) const override {
+    for (int i = 0; i < this->num_; ++i) {
       z[i] = x[i] * y[i];
     }
   }
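Because each `Impl` now carries its size, the pool that hands out kernels must cache one instance per (kernel type, size) pair. A hypothetical sketch of such a pool; the real `KernelPool`'s key scheme, locking, and ISA/block dispatch via registration macros are not shown in this diff, so everything here is an assumption for illustration:

```cpp
#include <map>
#include <memory>
#include <string>
#include <typeinfo>

struct Kernel {
  virtual ~Kernel() = default;
  int num_ = 0;
};

class KernelPool {
 public:
  static KernelPool& Instance() {
    static KernelPool pool;
    return pool;
  }
  template <typename KernelType>  // KernelType must be constructible from d
  std::shared_ptr<KernelType> Get(int d) {
    const std::string key =
        std::string(typeid(KernelType).name()) + "/" + std::to_string(d);
    auto it = kers_.find(key);
    if (it != kers_.end()) {
      return std::static_pointer_cast<KernelType>(it->second);
    }
    auto p = std::make_shared<KernelType>(d);  // d flows into the constructor
    kers_.emplace(key, p);
    return p;
  }

 private:
  std::map<std::string, std::shared_ptr<Kernel>> kers_;
};

struct VReluSketch : Kernel {  // stand-in kernel, not from the diff
  explicit VReluSketch(int d) { num_ = d; }
};

int main() {
  auto k8 = KernelPool::Instance().Get<VReluSketch>(8);
  auto k8b = KernelPool::Instance().Get<VReluSketch>(8);
  auto k16 = KernelPool::Instance().Get<VReluSketch>(16);
  return (k8 == k8b && k8 != k16) ? 0 : 1;  // same size shared, sizes distinct
}
```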
@@ -45,15 +46,15 @@ class VMulKernelImpl : public VMulKernel<T> {
 #define MKL_FLOAT(isa, block)                                         \
   template <>                                                         \
   void VMulKernelImpl<float, isa, block>::Compute(                    \
-      const int n, const float* x, const float* y, float* z) const {  \
-    platform::dynload::vsMul(n, x, y, z);                             \
+      const float* x, const float* y, float* z) const {               \
+    platform::dynload::vsMul(this->num_, x, y, z);                    \
   }
 #define MKL_DOUBLE(isa, block)                                           \
   template <>                                                            \
   void VMulKernelImpl<double, isa, block>::Compute(                      \
-      const int n, const double* x, const double* y, double* z) const {  \
-    platform::dynload::vdMul(n, x, y, z);                                \
+      const double* x, const double* y, double* z) const {               \
+    platform::dynload::vdMul(this->num_, x, y, z);                       \
   }
 FOR_EACH_ISA(MKL_FLOAT, kGT16);
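The `MKL_FLOAT`/`MKL_DOUBLE` macros stamp out template specializations that forward to MKL's vector-math routines, and `FOR_EACH_ISA` instantiates them across ISA tags. A standalone sketch of the pattern, with a stand-in `vsMul` in place of the dynloaded MKL symbol (same `(n, x, y, z)` argument order as the calls above):

```cpp
#include <cstdio>
#include <vector>

// Stand-in for platform::dynload::vsMul, which the real code loads from
// libmkl at runtime.
static void vsMul(int n, const float* x, const float* y, float* z) {
  for (int i = 0; i < n; ++i) z[i] = x[i] * y[i];
}

template <int kBlock>
class VMulSketch {
 public:
  explicit VMulSketch(int d) : num_(d) {}
  void Compute(const float* x, const float* y, float* z) const;

 private:
  int num_;
};

// Specialization mirroring MKL_FLOAT: the stored num_ is forwarded where the
// removed parameter n used to go.
template <>
void VMulSketch<16>::Compute(const float* x, const float* y, float* z) const {
  vsMul(num_, x, y, z);
}

int main() {
  std::vector<float> x(16, 2.f), y(16, 4.f), z(16);
  VMulSketch<16> ker(16);
  ker.Compute(x.data(), y.data(), z.data());
  std::printf("%g\n", z[0]);  // 8
}
```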
@@ -63,7 +64,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #define INTRI8_FLOAT(isa)                                             \
   template <>                                                         \
   void VMulKernelImpl<float, isa, kEQ8>::Compute(                     \
-      const int n, const float* x, const float* y, float* z) const {  \
+      const float* x, const float* y, float* z) const {               \
     __m256 tmpx, tmpy;                                                \
     tmpx = _mm256_loadu_ps(x);                                        \
     tmpy = _mm256_loadu_ps(y);                                        \
@@ -90,8 +91,9 @@ INTRI8_FLOAT(jit::avx512f);
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VAddKernelImpl : public VAddKernel<T> {
  public:
-  void Compute(const int n, const T* x, const T* y, T* z) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VAddKernelImpl(int d) : VAddKernel<T>() { this->num_ = d; }
+  void Compute(const T* x, const T* y, T* z) const override {
+    for (int i = 0; i < this->num_; ++i) {
       z[i] = x[i] + y[i];
     }
   }
@@ -101,15 +103,15 @@ class VAddKernelImpl : public VAddKernel<T> {
 #define MKL_FLOAT(isa, block)                                         \
   template <>                                                         \
   void VAddKernelImpl<float, isa, block>::Compute(                    \
-      const int n, const float* x, const float* y, float* z) const {  \
-    platform::dynload::vsAdd(n, x, y, z);                             \
+      const float* x, const float* y, float* z) const {               \
+    platform::dynload::vsAdd(this->num_, x, y, z);                    \
   }
 #define MKL_DOUBLE(isa, block)                                           \
   template <>                                                            \
   void VAddKernelImpl<double, isa, block>::Compute(                      \
-      const int n, const double* x, const double* y, double* z) const {  \
-    platform::dynload::vdAdd(n, x, y, z);                                \
+      const double* x, const double* y, double* z) const {               \
+    platform::dynload::vdAdd(this->num_, x, y, z);                       \
   }
 FOR_EACH_ISA(MKL_FLOAT, kGT16);
@@ -119,7 +121,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #define INTRI8_FLOAT(isa)                                             \
   template <>                                                         \
   void VAddKernelImpl<float, isa, kEQ8>::Compute(                     \
-      const int n, const float* x, const float* y, float* z) const {  \
+      const float* x, const float* y, float* z) const {               \
     __m256 tmpx, tmpy;                                                \
     tmpx = _mm256_loadu_ps(x);                                        \
     tmpy = _mm256_loadu_ps(y);                                        \
@@ -145,13 +147,14 @@ INTRI8_FLOAT(jit::avx512f);
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VScalKernelImpl : public VScalKernel<T> {
  public:
-  void Compute(const int n, const T a, const T* x, T* y) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
+  void Compute(const T a, const T* x, T* y) const override {
+    for (int i = 0; i < this->num_; ++i) {
       y[i] = a * x[i];
     }
   }
-  void Compute(const int n, const T a, T* x) const override {
-    for (int i = 0; i < n; ++i) {
+  void Compute(const T a, T* x) const override {
+    for (int i = 0; i < this->num_; ++i) {
       x[i] = a * x[i];
     }
   }
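`VScal` keeps two entry points after the change: an out-of-place form writing `y = a * x` and an in-place form scaling `x` directly, which saves one pass of loads and stores when the caller can overwrite its buffer. A compilable sketch of the pair:

```cpp
#include <vector>

// Sketch of the two VScal overloads; both read the length from the kernel
// rather than from the call site.
class VScalSketch {
 public:
  explicit VScalSketch(int d) : num_(d) {}
  void Compute(float a, const float* x, float* y) const {  // y = a * x
    for (int i = 0; i < num_; ++i) y[i] = a * x[i];
  }
  void Compute(float a, float* x) const {  // x = a * x, in place
    for (int i = 0; i < num_; ++i) x[i] = a * x[i];
  }

 private:
  int num_;
};

int main() {
  std::vector<float> x(8, 1.f), y(8);
  VScalSketch s(8);
  s.Compute(2.f, x.data(), y.data());  // out-of-place
  s.Compute(0.5f, y.data());           // in-place
}
```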
@@ -160,16 +163,16 @@ class VScalKernelImpl : public VScalKernel<T> {
 #ifdef PADDLE_WITH_MKLML
 #define MKL_FLOAT(isa, block)                                                  \
   template <>                                                                  \
-  void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
-                                                   float* x) const {           \
-    platform::dynload::cblas_sscal(n, a, x, 1);                                \
+  void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x)    \
+      const {                                                                  \
+    platform::dynload::cblas_sscal(this->num_, a, x, 1);                       \
   }
 #define MKL_DOUBLE(isa, block)                                                  \
   template <>                                                                   \
-  void VScalKernelImpl<double, isa, block>::Compute(                            \
-      const int n, const double a, double* x) const {                           \
-    platform::dynload::cblas_dscal(n, a, x, 1);                                 \
+  void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x)  \
+      const {                                                                   \
+    platform::dynload::cblas_dscal(this->num_, a, x, 1);                        \
   }
 FOR_EACH_ISA(MKL_FLOAT, kGT16);
@@ -179,7 +182,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #define INTRI8_FLOAT(isa)                                    \
   template <>                                                \
   void VScalKernelImpl<float, isa, kEQ8>::Compute(           \
-      const int n, const float a, const float* x, float* y) const { \
+      const float a, const float* x, float* y) const {       \
     __m256 tmp;                                              \
     __m256 scalar = _mm256_set1_ps(a);                       \
     tmp = _mm256_loadu_ps(x);                                \
@@ -188,8 +191,8 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
   }
 #define INTRI8_INPLACE_FLOAT(isa)                                             \
   template <>                                                                 \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
-                                                  float* x) const {           \
+  void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x)    \
+      const {                                                                 \
     __m256 tmp;                                                               \
     __m256 scalar = _mm256_set1_ps(a);                                        \
     tmp = _mm256_loadu_ps(x);                                                 \
@@ -220,8 +223,9 @@ INTRI8_INPLACE_FLOAT(jit::avx512f);
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VAddBiasKernelImpl : public VAddBiasKernel<T> {
  public:
-  void Compute(const int n, const T a, const T* x, T* y) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
+  void Compute(const T a, const T* x, T* y) const override {
+    for (int i = 0; i < this->num_; ++i) {
       y[i] = x[i] + a;
     }
   }
@@ -230,7 +234,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
 #define INTRI8_FLOAT(isa)                                            \
   template <>                                                        \
   void VAddBiasKernelImpl<float, isa, kEQ8>::Compute(                \
-      const int n, const float a, const float* x, float* y) const {  \
+      const float a, const float* x, float* y) const {               \
     __m256 tmp = _mm256_loadu_ps(x);                                 \
     tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a));                     \
     _mm256_storeu_ps(y, tmp);                                        \
@@ -239,7 +243,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
 #define INTRI16_FLOAT(isa)                                           \
   template <>                                                        \
   void VAddBiasKernelImpl<float, isa, kEQ16>::Compute(               \
-      const int n, const float a, const float* x, float* y) const {  \
+      const float a, const float* x, float* y) const {               \
     __m256 tmp0 = _mm256_loadu_ps(x);                                \
     __m256 tmp1 = _mm256_loadu_ps(x + 8);                            \
     tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a));                   \
......
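Both files specialize kernels on a `jit_block` size class in addition to the ISA. The diff never shows the enum itself, so the following is a hedged reconstruction of the dispatch implied by the `kLT8`/`kEQ8`/`kGT8LT16`/`kEQ16`/`kGT16` names and `AVX_FLOAT_BLOCK` (a 256-bit AVX register holds 8 floats); the real selection logic lives outside this commit:

```cpp
// Assumed size classes, reconstructed from the specialization names above.
enum jit_block { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 };

constexpr int AVX_FLOAT_BLOCK = 8;  // floats per 256-bit AVX register

inline jit_block BlockOf(int d) {
  if (d < AVX_FLOAT_BLOCK) return kLT8;          // scalar / MKL fallback
  if (d == AVX_FLOAT_BLOCK) return kEQ8;         // one register, no loop
  if (d < 2 * AVX_FLOAT_BLOCK) return kGT8LT16;  // one register plus a tail
  if (d == 2 * AVX_FLOAT_BLOCK) return kEQ16;    // exactly two registers
  return kGT16;                                  // blocked loop plus a tail
}
```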
@@ -40,8 +40,9 @@ namespace jit = platform::jit;
 template <typename T, jit::cpu_isa_t isa, jit_block>
 class VExpKernelImpl : public VExpKernel<T> {
  public:
-  void Compute(const int n, const T* x, T* y) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VExpKernelImpl(int d) : VExpKernel<T>() { this->num_ = d; }
+  void Compute(const T* x, T* y) const override {
+    for (int i = 0; i < this->num_; ++i) {
       y[i] = std::exp(x[i]);
     }
   }
@@ -50,16 +51,16 @@ class VExpKernelImpl : public VExpKernel<T> {
 #ifdef PADDLE_WITH_MKLML
 #define MKL_FLOAT(isa, block)                                                  \
   template <>                                                                  \
-  void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
-                                                  float* y) const {            \
-    platform::dynload::vsExp(n, x, y);                                         \
+  void VExpKernelImpl<float, isa, block>::Compute(const float* x, float* y)    \
+      const {                                                                  \
+    platform::dynload::vsExp(this->num_, x, y);                                \
   }
 #define MKL_DOUBLE(isa, block)                                                  \
   template <>                                                                   \
-  void VExpKernelImpl<double, isa, block>::Compute(                             \
-      const int n, const double* x, double* y) const {                          \
-    platform::dynload::vdExp(n, x, y);                                          \
+  void VExpKernelImpl<double, isa, block>::Compute(const double* x, double* y)  \
+      const {                                                                   \
+    platform::dynload::vdExp(this->num_, x, y);                                 \
  }
 FOR_EACH_ISA(MKL_FLOAT, kLT8);
 FOR_EACH_ISA(MKL_FLOAT, kGT8LT16);
@@ -69,16 +70,16 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #define INTRI8_FLOAT(isa)                                                     \
   template <>                                                                 \
-  void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
-                                                 float* y) const {            \
+  void VExpKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y)    \
+      const {                                                                 \
     __m256 tmp = _mm256_loadu_ps(x);                                          \
     _mm256_storeu_ps(y, detail::Exp(tmp));                                    \
   }
 #define INTRI16_FLOAT(isa)                                                      \
   template <>                                                                   \
-  void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x,  \
-                                                  float* y) const {             \
+  void VExpKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y)     \
+      const {                                                                   \
     __m256 tmp0 = _mm256_loadu_ps(x);                                           \
     __m256 tmp1 = _mm256_loadu_ps(x + 8);                                       \
     tmp0 = detail::Exp(tmp0);                                                   \
@@ -123,7 +124,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
       y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
       y[i] = static_cast<T>(0) - y[i];
     }
-    vexp_->Compute(this->num_, y, y);
+    vexp_->Compute(y, y);
     for (int i = 0; i < this->num_; ++i) {
       y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
     }
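The hunk above shows the sigmoid recipe: clip `x` into `[min, max]`, negate, exponentiate (the step delegated to the shared `VExpKernel`), then finish with `1 / (1 + e)`. A scalar reference sketch; the clip bounds are defined outside this hunk, so the -40/13 used here is an assumption based on the framework's usual sigmoid thresholds:

```cpp
#include <cmath>
#include <cstdio>

// Reference implementation of the clipped sigmoid computed via exp.
void SigmoidRef(int n, const float* x, float* y) {
  const float min = -40.f, max = 13.f;  // assumed clip range
  for (int i = 0; i < n; ++i) {
    float v = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
    y[i] = -v;  // negate so one exp pass yields exp(-x)
  }
  for (int i = 0; i < n; ++i) y[i] = std::exp(y[i]);  // vexp_->Compute(y, y)
  for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + y[i]);
}

int main() {
  float x = 0.f, y;
  SigmoidRef(1, &x, &y);
  std::printf("%g\n", y);  // 0.5
}
```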
@@ -173,7 +174,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
     this->num_ = d;                                                          \
     this->end_ = AVX_FLOAT_BLOCK;                                            \
     this->rest_ = d - this->end_;                                            \
-    vexp_ = KernelPool::Instance().template Get<VExpKernel<float>>(d);       \
+    vexp_ =                                                                  \
+        KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
   }                                                                          \
   template <>                                                                \
   void VSigmoidKernelImpl<float, isa, kGT8LT16>::Compute(const float* x,     \
@@ -189,7 +191,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
       y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]);           \
       y[i] = 0.f - y[i];                                                     \
     }                                                                        \
-    vexp_->Compute(this->rest_, y + this->end_, y + this->end_);             \
+    vexp_->Compute(y + this->end_, y + this->end_);                          \
     for (int i = this->end_; i < this->num_; ++i) {                          \
       y[i] = 1.f / (1.f + y[i]);                                             \
     }                                                                        \
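This is the "fix" half of the commit message. Previously the tail call passed its own length, so a `vexp_` fetched for the full size `d` still worked on the `rest_`-element tail; with the length now baked into the kernel, the cached `vexp_` must be requested at the tail size:

```cpp
// Before: length passed per call, so the kernel's own size d was harmless.
//   vexp_ = KernelPool::Instance().template Get<VExpKernel<float>>(d);
//   vexp_->Compute(this->rest_, y + this->end_, y + this->end_);
// After: the length comes from the kernel, so it must be sized for the tail.
//   vexp_ =
//       KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_);
//   vexp_->Compute(y + this->end_, y + this->end_);
```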
@@ -202,7 +204,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
     this->num_ = d;                                                          \
     this->rest_ = d % AVX_FLOAT_BLOCK;                                       \
     this->end_ = d - this->rest_;                                            \
-    vexp_ = KernelPool::Instance().template Get<VExpKernel<float>>(d);       \
+    vexp_ =                                                                  \
+        KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
   }                                                                          \
   template <>                                                                \
   void VSigmoidKernelImpl<float, isa, kGT16>::Compute(const float* x,        \
@@ -220,7 +223,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
       y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]);           \
       y[i] = 0.f - y[i];                                                     \
     }                                                                        \
-    vexp_->Compute(this->rest_, y + this->end_, y + this->end_);             \
+    vexp_->Compute(y + this->end_, y + this->end_);                          \
     for (int i = this->end_; i < this->num_; ++i) {                          \
       y[i] = 1.f / (1.f + y[i]);                                             \
     }                                                                        \
@@ -251,12 +254,7 @@ INTRI16_FLOAT(jit::avx512f);
 #undef INTRI_GT16_FLOAT
 #undef INTRI_VSIGMOID
-#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k)   \
-  p = std::dynamic_pointer_cast<ker<dtype>>(         \
-      std::make_shared<ker##Impl<dtype, isa, k>>(d))
-REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE,
-                        JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL);
+REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
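With size-taking constructors everywhere, the activation kernels no longer need the special-case `JITKERNEL_NEW_ACT_IMPL` factory; the shared factory macro, updated later in this commit to pass `d` (see the `@@ -57,7 +57,7` hunk below), builds them too, so plain `REGISTER_JITKERNEL` suffices:

```cpp
// The shared factory after this commit (quoted from the macro hunk below):
//   #define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
//     p = std::dynamic_pointer_cast<ker<dtype>>(   \
//         std::make_shared<ker##Impl<dtype, isa, k>>(d))
```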
 /* VTanh JitKernel */
 template <typename T, jit::cpu_isa_t isa, jit_block>
@@ -269,10 +267,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
   }
   void Compute(const T* x, T* y) const override {
-    vscal_->Compute(this->num_, static_cast<T>(2), x, y);
+    vscal_->Compute(static_cast<T>(2), x, y);
     vsigmoid_->Compute(y, y);
-    vscal_->Compute(this->num_, static_cast<T>(2), y);
-    vaddbias_->Compute(this->num_, static_cast<T>(-1), y, y);
+    vscal_->Compute(static_cast<T>(2), y);
+    vaddbias_->Compute(static_cast<T>(-1), y, y);
   }
  private:
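The `VTanh` composition above implements the identity tanh(x) = 2·sigmoid(2x) − 1: scale by 2 (`vscal_`), sigmoid (`vsigmoid_`), scale by 2 in place (`vscal_`), add bias −1 (`vaddbias_`). A scalar reference check of the same chain:

```cpp
#include <cmath>
#include <cstdio>

// Step-by-step scalar mirror of VTanhKernelImpl::Compute.
float TanhViaSigmoid(float x) {
  float y = 2.f * x;               // vscal_->Compute(2, x, y)
  y = 1.f / (1.f + std::exp(-y));  // vsigmoid_->Compute(y, y)
  y = 2.f * y;                     // vscal_->Compute(2, y), in place
  return y - 1.f;                  // vaddbias_->Compute(-1, y, y)
}

int main() {
  // Both prints agree: 2*sigmoid(2x) - 1 == tanh(x).
  std::printf("%g vs %g\n", TanhViaSigmoid(0.5f), std::tanh(0.5f));
}
```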
@@ -332,10 +330,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     _mm256_storeu_ps(y, tmp);                     \
     x += AVX_FLOAT_BLOCK;                         \
     y += AVX_FLOAT_BLOCK;                         \
-    vscal_->Compute(this->rest_, 2.f, x, y);      \
+    vscal_->Compute(2.f, x, y);                   \
     vsigmoid_->Compute(y, y);                     \
-    vscal_->Compute(this->rest_, 2.f, y);         \
-    vaddbias_->Compute(this->rest_, -1.f, y, y);  \
+    vscal_->Compute(2.f, y);                      \
+    vaddbias_->Compute(-1.f, y, y);               \
   }
 #define INTRI_GT16_FLOAT(isa)                     \
@@ -362,10 +360,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     }                                             \
     x += this->end_;                              \
     y += this->end_;                              \
-    vscal_->Compute(this->rest_, 2.f, x, y);      \
+    vscal_->Compute(2.f, x, y);                   \
     vsigmoid_->Compute(y, y);                     \
-    vscal_->Compute(this->rest_, 2.f, y);         \
-    vaddbias_->Compute(this->rest_, -1.f, y, y);  \
+    vscal_->Compute(2.f, y);                      \
+    vaddbias_->Compute(-1.f, y, y);               \
   }
 #ifdef __AVX__
@@ -391,8 +389,7 @@ INTRI16_FLOAT(jit::avx512f);
 #undef INTRI_GT16_FLOAT
 #undef INTRI_VTANH
-REGISTER_JITKERNEL_ARGS(vtanh, VTanhKernel, JITKERNEL_DECLARE, JITKERNEL_KEY,
-                        JITKERNEL_NEW_ACT_IMPL);
+REGISTER_JITKERNEL(vtanh, VTanhKernel);
 #undef JITKERNEL_NEW_ACT_IMPL
......
@@ -57,7 +57,7 @@ namespace jit = platform::jit;
 #define JITKERNEL_NEW_IMPL(ker, dtype, isa, k)    \
   p = std::dynamic_pointer_cast<ker<dtype>>(      \
-      std::make_shared<ker##Impl<dtype, isa, k>>())
+      std::make_shared<ker##Impl<dtype, isa, k>>(d))
 #define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key,  \
                              marco_declare, macro_key, macro_impl)      \
......
@@ -73,7 +73,7 @@ TEST(JitKernel, vaddbias) {
   auto trefe = GetCurrentUS();
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(d, a, x_data, ztgt_data);
+    ker->Compute(a, x_data, ztgt_data);
   }
   auto ttgte = GetCurrentUS();
@@ -99,7 +99,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
 TEST(JitKernel, vexp) {
   namespace jit = paddle::operators::math::jitkernel;
-  for (int d : {7, 8, 15, 16, 30, 128}) {
+  for (int d : {7, 8, 15, 16, 30, 128, 256}) {
     std::vector<float> x(d);
     std::vector<float> zref(d), ztgt(d);
     RandomVec<float>(d, x.data(), -2.f, 2.f);
@@ -124,7 +124,7 @@ TEST(JitKernel, vexp) {
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(d, x_data, ztgt_data);
+    ker->Compute(x_data, ztgt_data);
   }
   auto ttgte = GetCurrentUS();
@@ -164,7 +164,7 @@ void vsigmoid_better(
     y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
     y[i] = 0.f - y[i];
   }
-  vexp->Compute(n, y, y);
+  vexp->Compute(y, y);
   for (int i = 0; i < n; ++i) {
     y[i] = 1.f / (1.f + y[i]);
   }
@@ -226,10 +226,10 @@ void vtanh_better(
         const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
         vaddbias,
     const int n, const float* x, float* y) {
-  vscal->Compute(n, 2.f, x, y);
+  vscal->Compute(2.f, x, y);
   vsigmoid->Compute(y, y);
-  vscal->Compute(n, 2.f, y);
-  vaddbias->Compute(n, -1.f, y, y);
+  vscal->Compute(2.f, y);
+  vaddbias->Compute(-1.f, y, y);
 }
 TEST(JitKernel, vtanh) {
@@ -359,12 +359,12 @@ TEST(JitKernel, vscal) {
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(d, a, x_data, ztgt_data);
+    ker->Compute(a, x_data, ztgt_data);
   }
   auto ttgte = GetCurrentUS();
   auto ttgts1 = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(d, a, y_data);
+    ker->Compute(a, y_data);
   }
   auto ttgte1 = GetCurrentUS();
   VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
@@ -444,7 +444,7 @@ TEST(JitKernel, vmul) {
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(d, x_data, y_data, ztgt_data);
+    ker->Compute(x_data, y_data, ztgt_data);
  }
   auto ttgte = GetCurrentUS();
@@ -523,7 +523,7 @@ TEST(JitKernel, vadd) {
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(d, x_data, y_data, ztgt_data);
+    ker->Compute(x_data, y_data, ztgt_data);
   }
   auto ttgte = GetCurrentUS();
......
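The test hunks all follow the same timing pattern: run `repeat` calls through the new argument-free `Compute` interface and report the mean per-call cost (this commit also adds a 256-element case to the vexp sweep). A standalone sketch of that benchmark shape; `GetCurrentUS` in the real test is assumed to be a microsecond wall clock, reproduced here with `<chrono>`:

```cpp
#include <chrono>
#include <cstdio>
#include <vector>

// Assumed stand-in for the test helper GetCurrentUS().
static double GetCurrentUS() {
  auto now = std::chrono::steady_clock::now().time_since_epoch();
  return std::chrono::duration<double, std::micro>(now).count();
}

template <typename F>
double BenchUS(int repeat, F&& f) {
  auto start = GetCurrentUS();
  for (int i = 0; i < repeat; ++i) f();
  return (GetCurrentUS() - start) / repeat;
}

int main() {
  const int d = 256, repeat = 10000;  // 256 is the size this commit adds
  std::vector<float> x(d, 1.f), y(d);
  double us = BenchUS(repeat, [&] {
    for (int i = 0; i < d; ++i) y[i] = x[i] * 2.f;  // stand-in kernel body
  });
  std::printf("Vec size %d: %g us per call\n", d, us);
}
```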