Unverified commit 22125eba, authored by tensor-tang, committed via GitHub

Merge pull request #14321 from tensor-tang/fea/jit/vscal

Fea jitcode vscal vaddbias
Referenced by 1 merge request: !14726 Merge develop to Ce_debug
@@ -24,21 +24,30 @@ namespace gen {
 using namespace platform::jit;  // NOLINT

-bool VVVJitCode::init(int d) {
+bool VXXJitCode::init(int d, int scalar_index) {
   // It's not necessary to use avx512 since it would slow down the frequency
   // and this kernel is not compute bound.
-  return MayIUse(avx);
+  return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2;
 }

-void VVVJitCode::generate() {
+void VXXJitCode::generate() {
   // do not need push stack, and do not need save avx512reg if do not use avx512
   int offset = 0;
   if (with_relu_) {
     vxorps(ymm_zero, ymm_zero, ymm_zero);
   }
+  if (scalar_index_ == 1) {
+    vbroadcastss(ymm_src1, ptr[param1]);
+  } else if (scalar_index_ == 2) {
+    vbroadcastss(ymm_src2, ptr[param2]);
+  }
   for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
-    vmovups(ymm_src1, ptr[param1 + offset]);
-    vmovups(ymm_src2, ptr[param2 + offset]);
+    if (scalar_index_ != 1) {
+      vmovups(ymm_src1, ptr[param1 + offset]);
+    }
+    if (scalar_index_ != 2) {
+      vmovups(ymm_src2, ptr[param2 + offset]);
+    }
     if (type_ == operand_type::mul) {
       vmulps(ymm_dst, ymm_src1, ymm_src2);
     } else if (type_ == operand_type::add) {
@@ -52,8 +61,12 @@ void VVVJitCode::generate() {
   }
   int rest = num_ % AVX_FLOAT_BLOCK;
   if (rest >= 4) {
-    vmovups(xmm_src1, ptr[param1 + offset]);
-    vmovups(xmm_src2, ptr[param2 + offset]);
+    if (scalar_index_ != 1) {
+      vmovups(xmm_src1, ptr[param1 + offset]);
+    }
+    if (scalar_index_ != 2) {
+      vmovups(xmm_src2, ptr[param2 + offset]);
+    }
     if (type_ == operand_type::mul) {
       vmulps(xmm_dst, xmm_src1, xmm_src2);
     } else if (type_ == operand_type::add) {
@@ -67,8 +80,12 @@ void VVVJitCode::generate() {
     rest -= 4;
   }
   if (rest >= 2) {
-    vmovq(xmm_src1, ptr[param1 + offset]);
-    vmovq(xmm_src2, ptr[param2 + offset]);
+    if (scalar_index_ != 1) {
+      vmovq(xmm_src1, ptr[param1 + offset]);
+    }
+    if (scalar_index_ != 2) {
+      vmovq(xmm_src2, ptr[param2 + offset]);
+    }
     if (type_ == operand_type::mul) {
       vmulps(xmm_dst, xmm_src1, xmm_src2);
     } else if (type_ == operand_type::add) {
@@ -82,8 +99,12 @@ void VVVJitCode::generate() {
     rest -= 2;
   }
   if (rest > 0) {
-    vmovss(xmm_src1, ptr[param1 + offset]);
-    vmovss(xmm_src2, ptr[param2 + offset]);
+    if (scalar_index_ != 1) {
+      vmovss(xmm_src1, ptr[param1 + offset]);
+    }
+    if (scalar_index_ != 2) {
+      vmovss(xmm_src2, ptr[param2 + offset]);
+    }
     if (type_ == operand_type::mul) {
       vmulss(xmm_dst, xmm_src1, xmm_src2);
     } else if (type_ == operand_type::add) {
@@ -96,6 +117,7 @@ void VVVJitCode::generate() {
   }
   ret();
 }
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
......
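For readers of the diff above: scalar_index selects which operand the generated kernel treats as a broadcast scalar. 0 keeps both operands as vectors (the old VVVJitCode behavior), 1 broadcasts the first operand, and 2 broadcasts the second. A minimal reference sketch of the semantics for operand_type::mul follows (the function name is illustrative, not part of the source; the add case is analogous):

// Reference semantics of the generated VXXJitCode kernel (mul case).
// scalar_index == 0: z[i] = x[i] * y[i]   -- both vectors (old VVVJitCode)
// scalar_index == 1: z[i] = x[0] * y[i]   -- first operand broadcast
// scalar_index == 2: z[i] = x[i] * y[0]   -- second operand broadcast
void vxx_mul_ref(const float* x, const float* y, float* z, int n,
                 int scalar_index, bool with_relu) {
  for (int i = 0; i < n; ++i) {
    float a = (scalar_index == 1) ? x[0] : x[i];
    float b = (scalar_index == 2) ? y[0] : y[i];
    float v = a * b;
    z[i] = with_relu ? (v > 0.f ? v : 0.f) : v;
  }
}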
@@ -29,33 +29,46 @@ using ymm_t = const Xbyak::Ymm;
 using zmm_t = const Xbyak::Zmm;
 using Label = Xbyak::Label;

-// function: vec = Operand(vec, vec) (maybe with relu)
 typedef enum { mul = 0, add } operand_type;

-class VVVJitCode : public JitCode {
+// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
+class VXXJitCode : public JitCode {
  public:
   const char* name() const override {
-    std::string base = "VVVJitCode";
+    std::string base = "VXXJitCode";
+    if (scalar_index_ == 1) {
+      base += "_Scalar";
+    } else {
+      base += "_Vec";
+    }
     if (type_ == operand_type::mul) {
       base += "_Mul";
     } else if (type_ == operand_type::add) {
       base += "_Add";
     }
-    base += (with_relu_ ? "_relu" : "");
+    if (scalar_index_ == 2) {
+      base += "_Scalar";
+    } else {
+      base += "_Vec";
+    }
+    base += (with_relu_ ? "_Relu" : "");
     return base.c_str();
   }
-  explicit VVVJitCode(int d, operand_type type, bool with_relu,
-                      size_t code_size = 256 * 1024, void* code_ptr = nullptr)
+  explicit VXXJitCode(int d, operand_type type, int scalar_index,
+                      bool with_relu, size_t code_size = 256 * 1024,
+                      void* code_ptr = nullptr)
       : JitCode(code_size, code_ptr),
         num_(d),
         type_(type),
+        scalar_index_(scalar_index),
         with_relu_(with_relu) {}
-  static bool init(int d);
+  static bool init(int d, int scalar_index = 0);
   void generate() override;

  private:
   int num_;
   operand_type type_;
+  int scalar_index_;
   bool with_relu_;
   reg64_t param1{abi_param1};
   reg64_t param2{abi_param2};
@@ -63,13 +76,13 @@ class VVVJitCode : public JitCode {
   xmm_t xmm_src1 = xmm_t(0);
   xmm_t xmm_src2 = xmm_t(1);
-  xmm_t xmm_dst = xmm_t(1);
-  xmm_t xmm_zero = xmm_t(2);
+  xmm_t xmm_dst = xmm_t(2);
+  xmm_t xmm_zero = xmm_t(3);
   ymm_t ymm_src1 = ymm_t(0);
   ymm_t ymm_src2 = ymm_t(1);
-  ymm_t ymm_dst = ymm_t(1);
-  ymm_t ymm_zero = ymm_t(2);
+  ymm_t ymm_dst = ymm_t(2);
+  ymm_t ymm_zero = ymm_t(3);
 };

 }  // namespace gen
......
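Two details in this header are worth spelling out. First, ymm_dst/xmm_dst move from register 1 to 2 (and the zero register from 2 to 3): a broadcast scalar in src2 is loaded once before the loop and must stay live across iterations, so the destination can no longer alias src2 as it did in VVVJitCode. Second, name() now encodes which operands are scalars; some illustrative outputs, derived from the logic above rather than quoted from the source:

// VXXJitCode(d, mul, /*scalar_index=*/1, /*with_relu=*/false)
//   -> "VXXJitCode_Scalar_Mul_Vec"     (the vscal kernel: y = a * x)
// VXXJitCode(d, add, /*scalar_index=*/1, /*with_relu=*/false)
//   -> "VXXJitCode_Scalar_Add_Vec"     (the vaddbias kernel: y = a + x)
// VXXJitCode(d, add, /*scalar_index=*/0, /*with_relu=*/true)
//   -> "VXXJitCode_Vec_Add_Vec_Relu"   (the vaddrelu kernel)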
@@ -83,14 +83,15 @@ class VAddReluKernel : public Kernel {
 template <typename T>
 class VScalKernel : public Kernel {
  public:
-  virtual void Compute(const T a, const T *x, T *y) const = 0;
-  virtual void Compute(const T a, T *x) const = 0;
+  // y = a.*x
+  void (*Compute)(const T *, const T *, T *, int);
 };

 template <typename T>
 class VAddBiasKernel : public Kernel {
  public:
-  virtual void Compute(const T a, const T *x, T *y) const = 0;
+  // y = a.+x
+  void (*Compute)(const T *, const T *, T *, int);
 };

 template <typename T>
......
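The interface change above replaces the virtual Compute overloads with a plain function pointer: the scalar is now passed by pointer, the length is a trailing argument, and in-place operation is expressed by passing the same pointer for input and output (exactly how the updated vscal test calls it). A minimal usage sketch, assuming float buffers x and y of length n:

const auto& vscal =
    KernelPool::Instance().template Get<VScalKernel<float>>(n);
const float a = 2.f;
vscal->Compute(&a, x, y, n);  // y[i] = a * x[i]
vscal->Compute(&a, y, y, n);  // in-place: y[i] = a * y[i]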
@@ -57,6 +57,20 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
   }
 }

+template <typename T>
+void VScalRefer(const T* a, const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = a[0] * x[i];
+  }
+}
+
+template <typename T>
+void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = a[0] + x[i];
+  }
+}
+
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
 void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -83,6 +97,28 @@ template <>
 void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
   platform::dynload::vdAdd(n, x, y, z);
 }
+
+template <typename T>
+void VScalMKL(const T* a, const T* x, T* y, int n);
+
+template <>
+void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
+  if (x == y) {
+    platform::dynload::cblas_sscal(n, *a, y, 1);
+  } else {
+    VScalRefer<float>(a, x, y, n);
+  }
+}
+
+template <>
+void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
+  if (x == y) {
+    platform::dynload::cblas_dscal(n, *a, y, 1);
+  } else {
+    VScalRefer<double>(a, x, y, n);
+  }
+}
+
 #endif

 #define DECLARE_STATIC_FUNC \
@@ -102,7 +138,7 @@ class VMulKernelImpl : public VMulKernel<T> {
     if (useJIT(d)) {
       // roughly estimate the size of code
       size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false,
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
                                          sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -121,14 +157,14 @@ class VMulKernelImpl : public VMulKernel<T> {
 #ifdef PADDLE_WITH_XBYAK
  private:
-  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
 #endif
 };

 #ifdef PADDLE_WITH_XBYAK
 template <>
 bool VMulKernelImpl<float>::useJIT(int d) {
-  return gen::VVVJitCode::init(d);
+  return gen::VXXJitCode::init(d);
 }
 #endif
@@ -153,7 +189,7 @@ class VAddKernelImpl : public VAddKernel<T> {
 #ifdef PADDLE_WITH_XBYAK
     if (useJIT(d)) {
       size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false,
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
                                          sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -171,14 +207,14 @@ class VAddKernelImpl : public VAddKernel<T> {
 #ifdef PADDLE_WITH_XBYAK
  private:
-  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
 #endif
 };

 #ifdef PADDLE_WITH_XBYAK
 template <>
 bool VAddKernelImpl<float>::useJIT(int d) {
-  return gen::VVVJitCode::init(d);
+  return gen::VXXJitCode::init(d);
 }
 #endif
@@ -203,7 +239,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
 #ifdef PADDLE_WITH_XBYAK
     if (useJIT(d)) {
       size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true,
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
                                          sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -215,148 +251,106 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
 #ifdef PADDLE_WITH_XBYAK
  private:
-  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
 #endif
 };

 #ifdef PADDLE_WITH_XBYAK
 template <>
 bool VAddReluKernelImpl<float>::useJIT(int d) {
-  return gen::VVVJitCode::init(d);
+  return gen::VXXJitCode::init(d);
 }
 #endif
-#undef DECLARE_STATIC_FUNC
-
-REGISTER_JITKERNEL(vmul, VMulKernel);
-REGISTER_JITKERNEL(vadd, VAddKernel);
-REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
-
-/* VSCAL JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+/* VScal JitKernel */
+template <typename T>
 class VScalKernelImpl : public VScalKernel<T> {
  public:
-  explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
-  void Compute(const T a, const T* x, T* y) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      y[i] = a * x[i];
-    }
-  }
-  void Compute(const T a, T* x) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      x[i] = a * x[i];
-    }
-  }
-};
-
-#ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block)                                               \
-  template <>                                                               \
-  void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
-      const {                                                               \
-    platform::dynload::cblas_sscal(this->num_, a, x, 1);                    \
-  }
-#define MKL_DOUBLE(isa, block)                                                 \
-  template <>                                                                  \
-  void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
-      const {                                                                  \
-    platform::dynload::cblas_dscal(this->num_, a, x, 1);                       \
-  }
-FOR_EACH_ISA(MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
-#endif
-
-#define INTRI8_FLOAT(isa)                              \
-  template <>                                          \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(     \
-      const float a, const float* x, float* y) const { \
-    __m256 tmp;                                        \
-    __m256 scalar = _mm256_set1_ps(a);                 \
-    tmp = _mm256_loadu_ps(x);                          \
-    tmp = _mm256_mul_ps(tmp, scalar);                  \
-    _mm256_storeu_ps(y, tmp);                          \
-  }
-#define INTRI8_INPLACE_FLOAT(isa)                                          \
-  template <>                                                              \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
-      const {                                                              \
-    __m256 tmp;                                                            \
-    __m256 scalar = _mm256_set1_ps(a);                                     \
-    tmp = _mm256_loadu_ps(x);                                              \
-    tmp = _mm256_mul_ps(tmp, scalar);                                      \
-    _mm256_storeu_ps(x, tmp);                                              \
-  }
-
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI8_INPLACE_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI8_INPLACE_FLOAT(jit::avx2);
-#endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI8_INPLACE_FLOAT(jit::avx512f);
-#endif
-// TODO(TJ): eq16 test and complete avx512
-#undef INTRI8_FLOAT
-#undef INTRI8_INPLACE_FLOAT
-#undef MKL_FLOAT
-#undef MKL_DOUBLE
+  DECLARE_STATIC_FUNC;
+  explicit VScalKernelImpl(int d) : VScalKernel<T>() {
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
+                                         sz > 4096 ? sz : 4096));
+      this->Compute =
+          jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
+      return;
+    }
+#endif
+#ifdef PADDLE_WITH_MKLML
+    if (useMKL(d)) {
+      this->Compute = VScalMKL<T>;
+      return;
+    }
+#endif
+    this->Compute = VScalRefer<T>;
+  }
+
+#ifdef PADDLE_WITH_XBYAK
+ private:
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
+#endif
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool VScalKernelImpl<float>::useJIT(int d) {
+  return gen::VXXJitCode::init(d, 1);
+}
+#endif
+
+#ifdef PADDLE_WITH_MKLML
+template <>
+bool VScalKernelImpl<float>::useMKL(int d) {
+  return d > 512;
+}
+template <>
+bool VScalKernelImpl<double>::useMKL(int d) {
+  return true;
+}
+#endif
 /* VAddBias JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+template <typename T>
 class VAddBiasKernelImpl : public VAddBiasKernel<T> {
  public:
-  explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
-  void Compute(const T a, const T* x, T* y) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      y[i] = x[i] + a;
-    }
-  }
-};
-
-#define INTRI8_FLOAT(isa)                              \
-  template <>                                          \
-  void VAddBiasKernelImpl<float, isa, kEQ8>::Compute(  \
-      const float a, const float* x, float* y) const { \
-    __m256 tmp = _mm256_loadu_ps(x);                   \
-    tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a));       \
-    _mm256_storeu_ps(y, tmp);                          \
-  }
-
-#define INTRI16_FLOAT(isa)                             \
-  template <>                                          \
-  void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
-      const float a, const float* x, float* y) const { \
-    __m256 tmp0 = _mm256_loadu_ps(x);                  \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8);              \
-    tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a));     \
-    tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a));     \
-    _mm256_storeu_ps(y, tmp0);                         \
-    _mm256_storeu_ps(y + 8, tmp1);                     \
-  }
-
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
-#endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
-#endif
-// TODO(TJ): eq16 test and complete avx512
-#undef INTRI8_FLOAT
-#undef INTRI16_FLOAT
+  DECLARE_STATIC_FUNC;
+  explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
+                                         sz > 4096 ? sz : 4096));
+      this->Compute =
+          jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
+      return;
+    }
+#endif
+
+    this->Compute = VAddBiasRefer<T>;
+  }
+
+#ifdef PADDLE_WITH_XBYAK
+ private:
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
+#endif
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool VAddBiasKernelImpl<float>::useJIT(int d) {
+  return gen::VXXJitCode::init(d, 1);
+}
+#endif
+
+#undef DECLARE_STATIC_FUNC
+
+REGISTER_JITKERNEL(vmul, VMulKernel);
+REGISTER_JITKERNEL(vadd, VAddKernel);
+REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
+REGISTER_JITKERNEL(vscal, VScalKernel);
+REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);

 /* VRelu JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
@@ -467,8 +461,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
   void Compute(const T* x, T* y) const override {}
 };

-REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
-REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
 REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
 REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
......
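All of the rewritten constructors above share one dispatch policy: point Compute at JIT-generated code when xbyak is available and useJIT(d) holds, else at the MKL wrapper when useMKL(d) holds (for float vscal only when d > 512, and VScalMKL itself falls back to the reference loop when x != y, since cblas_?scal is in-place only), else at the portable reference loop. A self-contained sketch of that policy; the names here are stand-ins, not the PaddlePaddle source:

#include <cstdio>

using ComputeFn = void (*)(const float*, const float*, float*, int);

// Mirrors VScalRefer: y = a.*x.
static void ScalRefer(const float* a, const float* x, float* y, int n) {
  for (int i = 0; i < n; ++i) y[i] = a[0] * x[i];
}
// Stand-ins for the JIT- and MKL-backed paths.
static void ScalJit(const float* a, const float* x, float* y, int n) {
  ScalRefer(a, x, y, n);
}
static void ScalMkl(const float* a, const float* x, float* y, int n) {
  ScalRefer(a, x, y, n);
}

static ComputeFn PickVScalFloat(int d, bool jit_ok, bool mkl_ok) {
  if (jit_ok) return ScalJit;             // useJIT(d): gen::VXXJitCode::init(d, 1)
  if (mkl_ok && d > 512) return ScalMkl;  // float useMKL(d) is d > 512
  return ScalRefer;                       // portable fallback
}

int main() {
  const float a = 2.f, x[4] = {1.f, 2.f, 3.f, 4.f};
  float y[4];
  PickVScalFloat(4, false, false)(&a, x, y, 4);
  std::printf("%g %g %g %g\n", y[0], y[1], y[2], y[3]);  // prints: 2 4 6 8
}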
@@ -409,10 +409,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
   }
   void Compute(const T* x, T* y) const override {
-    vscal_->Compute(static_cast<T>(2), x, y);
+    const T a = static_cast<T>(2), b = static_cast<T>(-1);
+    vscal_->Compute(&a, x, y, this->num_);
     vsigmoid_->Compute(y, y);
-    vscal_->Compute(static_cast<T>(2), y);
-    vaddbias_->Compute(static_cast<T>(-1), y, y);
+    vscal_->Compute(&a, y, y, this->num_);
+    vaddbias_->Compute(&b, y, y, this->num_);
   }

  private:
@@ -472,10 +473,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     _mm256_storeu_ps(y, tmp);                  \
     x += AVX_FLOAT_BLOCK;                      \
     y += AVX_FLOAT_BLOCK;                      \
-    vscal_->Compute(2.f, x, y);                \
+    const float a = 2.f, b = -1.f;             \
+    vscal_->Compute(&a, x, y, this->num_);     \
     vsigmoid_->Compute(y, y);                  \
-    vscal_->Compute(2.f, y);                   \
-    vaddbias_->Compute(-1.f, y, y);            \
+    vscal_->Compute(&a, y, y, this->num_);     \
+    vaddbias_->Compute(&b, y, y, this->num_);  \
   }

 #define INTRI_GT16_FLOAT(isa, expisa)          \
@@ -502,10 +504,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     }                                          \
     x += this->end_;                           \
     y += this->end_;                           \
-    vscal_->Compute(2.f, x, y);                \
+    const float a = 2.f, b = -1.f;             \
+    vscal_->Compute(&a, x, y, this->num_);     \
     vsigmoid_->Compute(y, y);                  \
-    vscal_->Compute(2.f, y);                   \
-    vaddbias_->Compute(-1.f, y, y);            \
+    vscal_->Compute(&a, y, y, this->num_);     \
+    vaddbias_->Compute(&b, y, y, this->num_);  \
   }
 #ifdef __AVX__
......
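For context on the vscal/vaddbias chains above and in the test below: VTanh is composed from the identity tanh(x) = 2 * sigmoid(2x) - 1, so each hunk is a mechanical re-plumbing of the same three calls to the new pointer-based signature. A scalar sketch of the composition (function name illustrative):

#include <cmath>

// tanh(x) == 2 * sigmoid(2 * x) - 1, evaluated the way the kernels chain it.
float tanh_via_sigmoid(float x) {
  float y = 2.f * x;               // vscal:    y = a * x with a = 2
  y = 1.f / (1.f + std::exp(-y));  // vsigmoid: y = sigmoid(y)
  y = 2.f * y;                     // vscal:    y = a * y (in place)
  return y + (-1.f);               // vaddbias: y = b + y with b = -1
}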
@@ -128,7 +128,7 @@ TEST(JitKernel, vaddbias) {
   auto trefe = GetCurrentUS();
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(a, x_data, ztgt_data);
+    ker->Compute(&a, x_data, ztgt_data, d);
   }
   auto ttgte = GetCurrentUS();
@@ -281,10 +281,11 @@ void vtanh_better(
     const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
         vaddbias,
     const int n, const float* x, float* y) {
-  vscal->Compute(2.f, x, y);
+  const float a = 2.f, b = -1.f;
+  vscal->Compute(&a, x, y, n);
   vsigmoid->Compute(y, y);
-  vscal->Compute(2.f, y);
-  vaddbias->Compute(-1.f, y, y);
+  vscal->Compute(&a, y, y, n);
+  vaddbias->Compute(&b, y, y, n);
 }

 TEST(JitKernel, vtanh) {
@@ -531,12 +532,12 @@ TEST(JitKernel, vscal) {
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(a, x_data, ztgt_data);
+    ker->Compute(&a, x_data, ztgt_data, d);
   }
   auto ttgte = GetCurrentUS();
   auto ttgts1 = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(a, y_data);
+    ker->Compute(&a, y_data, y_data, d);
   }
   auto ttgte1 = GetCurrentUS();
   VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
......