未验证 提交 23544096 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #14374 from tensor-tang/fea/jit/act

add vrelu jitcode
......@@ -118,6 +118,39 @@ void VXXJitCode::generate() {
ret();
}
bool ReluJitCode::init(int d) { return MayIUse(avx); }
void ReluJitCode::generate() {
int offset = 0;
vxorps(ymm_zero, ymm_zero, ymm_zero);
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]);
vmaxps(ymm_dst, ymm_zero, ymm_src);
vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK;
}
int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovups(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovq(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovss(ptr[param2 + offset], xmm_dst);
}
ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -85,6 +85,29 @@ class VXXJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(3);
};
class ReluJitCode : public JitCode {
public:
DECLARE_JIT_CODE(ReluJitCode);
explicit ReluJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
xmm_t xmm_zero = xmm_t(0);
xmm_t xmm_src = xmm_t(1);
xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_zero = ymm_t(0);
ymm_t ymm_src = ymm_t(1);
ymm_t ymm_dst = ymm_t(1);
};
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -97,37 +97,38 @@ class VAddBiasKernel : public Kernel {
template <typename T>
class VActKernel : public Kernel {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
class VReluKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T>
class VIdentityKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
class VExpKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
class VSigmoidKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
class VTanhKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
......
......@@ -71,6 +71,13 @@ void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
}
}
template <typename T>
void VReluRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
......@@ -344,124 +351,60 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
}
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
/* VRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
template <typename T>
class VReluKernelImpl : public VReluKernel<T> {
public:
explicit VReluKernelImpl(int d) : VReluKernel<T>() { this->num_ = d; }
void Compute(const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
DECLARE_STATIC_FUNC;
explicit VReluKernelImpl(int d) : VReluKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 /*init*/ +
d / AVX_FLOAT_BLOCK * 4 /* instructions*/ *
8 /*everage byte for each instruction*/;
jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VReluKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VReluKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#endif
#define INTRI_GT8LT16_FLOAT(isa) \
template <> \
VReluKernelImpl<float, isa, kGT8LT16>::VReluKernelImpl(int d) \
: VReluKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - AVX_FLOAT_BLOCK; \
} \
template <> \
void VReluKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
float* y) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + this->rest_); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + this->rest_, tmp1); \
this->Compute = VReluRefer<T>;
}
#define INTRI_GT16_FLOAT(isa) \
template <> \
VReluKernelImpl<float, isa, kGT16>::VReluKernelImpl(int d) \
: VReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - AVX_FLOAT_BLOCK; \
} \
template <> \
void VReluKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
tmp = _mm256_max_ps(tmp, zeros); \
_mm256_storeu_ps(y + i, tmp); \
} \
__m256 tmp = _mm256_loadu_ps(x + this->rest_); \
tmp = _mm256_max_ps(tmp, zeros); \
_mm256_storeu_ps(y + this->rest_, tmp); \
void ComputeDeprecated(const T* x, T* y) const override {
VReluRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_GT8LT16_FLOAT(jit::avx);
INTRI_GT16_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI_GT8LT16_FLOAT(jit::avx2);
INTRI_GT16_FLOAT(jit::avx2);
private:
std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr};
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI_GT8LT16_FLOAT(jit::avx512f);
INTRI_GT16_FLOAT(jit::avx512f);
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VReluKernelImpl<float>::useJIT(int d) {
return gen::ReluJitCode::init(d);
}
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
/* An empty JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
void Compute(const T* x, T* y) const override {}
void ComputeDeprecated(const T* x, T* y) const override {}
};
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
} // namespace jitkernel
......
......@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
T* checked) const override {
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_->Compute(gates + d_, gates + d_);
act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_->Compute(gates, gates);
act_cand_d_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_);
act_cand_d_->Compute(gates, gates);
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
act_cand_d_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_);
act_cell_d_->Compute(ct, gates + d2_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
......@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->Compute(gates + d_, gates + d_);
act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->Compute(gates, gates);
act_cand_d_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->Compute(gates + d3_, gates + d3_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_);
act_cand_d_->Compute(gates, gates);
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
act_cand_d_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_);
act_cell_d_->Compute(ct, gates + d2_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
......@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
}
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates);
act_state_d_->Compute(gates + d2_, gates + d2_);
act_gate_d_->ComputeDeprecated(gates, gates);
act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_);
vmul_d_->Compute(gates, gates + d2_, ht, d_);
}
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates);
act_gate_d2_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(ht_1, gates + d_, ht, d_);
}
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_;
act_state_d_->Compute(y, y);
act_state_d_->ComputeDeprecated(y, y);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
......
......@@ -92,7 +92,7 @@ TEST(JitKernel, vrelu) {
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(30) << "Vec size " << d
......@@ -181,7 +181,7 @@ TEST(JitKernel, vexp) {
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
ker->ComputeDeprecated(x_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
......@@ -222,7 +222,7 @@ void vsigmoid_better(
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 0.f - y[i];
}
vexp->Compute(y, y);
vexp->ComputeDeprecated(y, y);
for (int i = 0; i < n; ++i) {
y[i] = 1.f / (1.f + y[i]);
}
......@@ -253,7 +253,7 @@ TEST(JitKernel, vsigmoid) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
ker->ComputeDeprecated(x_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
......@@ -287,7 +287,7 @@ void vtanh_better(
const int n, const float* x, float* y) {
const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n);
vsigmoid->Compute(y, y);
vsigmoid->ComputeDeprecated(y, y);
vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n);
}
......@@ -321,7 +321,7 @@ TEST(JitKernel, vtanh) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
ker->ComputeDeprecated(x_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
......@@ -344,8 +344,8 @@ void lstm_ctht_ref(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
vsigmoid_3d->Compute(gates + d, gates + d);
vtanh_d->Compute(gates, gates);
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d);
vtanh_d->ComputeDeprecated(gates, gates);
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
......@@ -355,7 +355,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated
float tmp = ct[k] * 2;
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vexp_1->Compute(&tmp, &tmp);
vexp_1->ComputeDeprecated(&tmp, &tmp);
tmp = 2.f / (1.f + tmp) - 1.f;
ht[k] = tmp * o[k];
}
......@@ -373,13 +373,13 @@ void lstm_ctht_better(
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
int d2 = d * 2;
vsigmoid_3d->Compute(gates + d, gates + d);
vtanh_d->Compute(gates, gates);
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d);
vtanh_d->ComputeDeprecated(gates, gates);
vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */
vtanh_d->Compute(ct, gates + d2);
vtanh_d->ComputeDeprecated(ct, gates + d2);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
}
......@@ -736,7 +736,7 @@ void vaddrelu_better(
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z, int d) {
vadd->Compute(x, y, z, d);
vrelu->Compute(z, z);
vrelu->ComputeDeprecated(z, z);
}
TEST(JitKernel, vaddrelu) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册