提交 b68ececb 编写于 作者: T tensor-tang

add vaddrelu jitcode

test=develop
上级 bb09e310
...@@ -70,10 +70,16 @@ bool VAddJitCode::init(int d) { return MayIUse(avx); } ...@@ -70,10 +70,16 @@ bool VAddJitCode::init(int d) { return MayIUse(avx); }
void VAddJitCode::generate() { void VAddJitCode::generate() {
int offset = 0; int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src1, ptr[param1 + offset]); vmovups(ymm_src1, ptr[param1 + offset]);
vmovups(ymm_src2, ptr[param2 + offset]); vmovups(ymm_src2, ptr[param2 + offset]);
vaddps(ymm_dst, ymm_src1, ymm_src2); vaddps(ymm_dst, ymm_src1, ymm_src2);
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
vmovups(ptr[param3 + offset], ymm_dst); vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK; offset += sizeof(float) * AVX_FLOAT_BLOCK;
} }
...@@ -82,6 +88,9 @@ void VAddJitCode::generate() { ...@@ -82,6 +88,9 @@ void VAddJitCode::generate() {
vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
vmovups(xmm_src2, ptr[param2 + offset]); vmovups(xmm_src2, ptr[param2 + offset]);
vaddps(xmm_dst, xmm_src1, xmm_src2); vaddps(xmm_dst, xmm_src1, xmm_src2);
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovups(ptr[param3 + offset], xmm_dst); vmovups(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 4; offset += sizeof(float) * 4;
rest -= 4; rest -= 4;
...@@ -90,6 +99,9 @@ void VAddJitCode::generate() { ...@@ -90,6 +99,9 @@ void VAddJitCode::generate() {
vmovq(xmm_src1, ptr[param1 + offset]); vmovq(xmm_src1, ptr[param1 + offset]);
vmovq(xmm_src2, ptr[param2 + offset]); vmovq(xmm_src2, ptr[param2 + offset]);
vaddps(xmm_dst, xmm_src1, xmm_src2); vaddps(xmm_dst, xmm_src1, xmm_src2);
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovq(ptr[param3 + offset], xmm_dst); vmovq(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 2; offset += sizeof(float) * 2;
rest -= 2; rest -= 2;
...@@ -98,6 +110,9 @@ void VAddJitCode::generate() { ...@@ -98,6 +110,9 @@ void VAddJitCode::generate() {
vmovss(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src1, ptr[param1 + offset]);
vmovss(xmm_src2, ptr[param2 + offset]); vmovss(xmm_src2, ptr[param2 + offset]);
vaddss(xmm_dst, xmm_src1, xmm_src2); vaddss(xmm_dst, xmm_src1, xmm_src2);
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovss(ptr[param3 + offset], xmm_dst); vmovss(ptr[param3 + offset], xmm_dst);
} }
ret(); ret();
......
...@@ -46,35 +46,38 @@ class VMulJitCode : public JitCode { ...@@ -46,35 +46,38 @@ class VMulJitCode : public JitCode {
xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1); xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2); xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1); ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2); ymm_t ymm_dst = ymm_t(1);
}; };
class VAddJitCode : public JitCode { class VAddJitCode : public JitCode {
public: public:
DECLARE_JIT_CODE(VAddJitCode); DECLARE_JIT_CODE(VAddJitCode);
explicit VAddJitCode(int d, size_t code_size = 256 * 1024, explicit VAddJitCode(int d, bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {} : JitCode(code_size, code_ptr), num_(d), with_relu_(with_relu) {}
static bool init(int d); static bool init(int d);
void generate() override; void generate() override;
private: private:
int num_; int num_;
bool with_relu_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
reg64_t param2{abi_param2}; reg64_t param2{abi_param2};
reg64_t param3{abi_param3}; reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1); xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2); xmm_t xmm_dst = xmm_t(1);
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1); ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2); ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_zero = ymm_t(2);
}; };
} // namespace gen } // namespace gen
......
...@@ -75,22 +75,22 @@ class VAddKernel : public Kernel { ...@@ -75,22 +75,22 @@ class VAddKernel : public Kernel {
}; };
template <typename T> template <typename T>
class VScalKernel : public Kernel { class VAddReluKernel : public Kernel {
public: public:
virtual void Compute(const T a, const T *x, T *y) const = 0; void (*Compute)(const T *, const T *, T *, int);
virtual void Compute(const T a, T *x) const = 0;
}; };
template <typename T> template <typename T>
class VAddBiasKernel : public Kernel { class VScalKernel : public Kernel {
public: public:
virtual void Compute(const T a, const T *x, T *y) const = 0; virtual void Compute(const T a, const T *x, T *y) const = 0;
virtual void Compute(const T a, T *x) const = 0;
}; };
template <typename T> template <typename T>
class VAddReluKernel : public Kernel { class VAddBiasKernel : public Kernel {
public: public:
virtual void Compute(const T *x, const T *y, T *z) const = 0; virtual void Compute(const T a, const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
......
...@@ -46,6 +46,14 @@ void VAddRefer(const T* x, const T* y, T* z, int n) { ...@@ -46,6 +46,14 @@ void VAddRefer(const T* x, const T* y, T* z, int n) {
} }
} }
template <typename T>
void VAddReluRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
template <typename T> template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n); void VMulMKL(const T* x, const T* y, T* z, int n);
...@@ -131,7 +139,7 @@ class VAddKernelImpl : public VAddKernel<T> { ...@@ -131,7 +139,7 @@ class VAddKernelImpl : public VAddKernel<T> {
explicit VAddKernelImpl(int d) : VAddKernel<T>() { explicit VAddKernelImpl(int d) : VAddKernel<T>() {
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VAddJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VAddJitCode(d, false, sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return; return;
...@@ -164,10 +172,36 @@ bool VAddKernelImpl<double>::useMKL(int d) { ...@@ -164,10 +172,36 @@ bool VAddKernelImpl<double>::useMKL(int d) {
return true; return true;
} }
/* VAddRelu JitKernel */
template <typename T>
class VAddReluKernelImpl : public VAddReluKernel<T> {
public:
DECLARE_STATIC_FUNC;
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VAddJitCode(d, true, sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
this->Compute = VAddReluRefer<T>;
}
private:
std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
};
template <>
bool VAddReluKernelImpl<float>::useJIT(int d) {
return gen::VAddJitCode::init(d);
}
#undef DECLARE_STATIC_FUNC #undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel); REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
/* VSCAL JitKernel */ /* VSCAL JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
...@@ -404,97 +438,9 @@ class VIdentityKernelImpl : public VIdentityKernel<T> { ...@@ -404,97 +438,9 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void Compute(const T* x, T* y) const override {} void Compute(const T* x, T* y) const override {}
}; };
/* VAddRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddReluKernelImpl : public VAddReluKernel<T> {
public:
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { this->num_ = d; }
void Compute(const T* x, const T* y, T* z) const override {
for (int i = 0; i < this->num_; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 tmpx = _mm256_loadu_ps(x); \
__m256 tmpy = _mm256_loadu_ps(y); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \
_mm256_storeu_ps(z, tmpy); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ16>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(y); \
tmp0 = _mm256_add_ps(tmp0, tmp1); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp2 = _mm256_loadu_ps(y + 8); \
tmp1 = _mm256_add_ps(tmp1, tmp2); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(z, tmp0); \
_mm256_storeu_ps(z + 8, tmp1); \
}
#define INTRI_COMMON_FLOAT(isa, block) \
template <> \
VAddReluKernelImpl<float, isa, block>::VAddReluKernelImpl(int d) \
: VAddReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
} \
template <> \
void VAddReluKernelImpl<float, isa, block>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmpx = _mm256_loadu_ps(x + i); \
__m256 tmpy = _mm256_loadu_ps(y + i); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, zeros); \
_mm256_storeu_ps(z + i, tmpy); \
} \
for (int i = this->end_; i < this->num_; ++i) { \
z[i] = x[i] + y[i]; \
z[i] = z[i] > 0 ? z[i] : 0; \
} \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_COMMON_FLOAT(jit::avx, kGT16);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI_COMMON_FLOAT(jit::avx2, kGT16);
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel); REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
REGISTER_JITKERNEL_DEPRECATED(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
} // namespace jitkernel } // namespace jitkernel
......
...@@ -757,7 +757,7 @@ TEST(JitKernel, vaddrelu) { ...@@ -757,7 +757,7 @@ TEST(JitKernel, vaddrelu) {
auto tmkle = GetCurrentUS(); auto tmkle = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data); ker->Compute(x_data, y_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册