Commit bb09e310 authored by tensor-tang

add vadd jitcode

test=develop
Parent: d55481cf
@@ -66,6 +66,42 @@ void VMulJitCode::generate() {
  ret();
}
bool VAddJitCode::init(int d) { return MayIUse(avx); }

void VAddJitCode::generate() {
  // Main loop: one 256-bit vaddps handles AVX_FLOAT_BLOCK (8) floats.
  int offset = 0;
  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
    vmovups(ymm_src1, ptr[param1 + offset]);
    vmovups(ymm_src2, ptr[param2 + offset]);
    vaddps(ymm_dst, ymm_src1, ymm_src2);
    vmovups(ptr[param3 + offset], ymm_dst);
    offset += sizeof(float) * AVX_FLOAT_BLOCK;
  }
  // Tail: the remaining 0-7 floats, narrowed through 4-, 2- and 1-wide ops.
  int rest = num_ % AVX_FLOAT_BLOCK;
  if (rest >= 4) {
    vmovups(xmm_src1, ptr[param1 + offset]);
    vmovups(xmm_src2, ptr[param2 + offset]);
    vaddps(xmm_dst, xmm_src1, xmm_src2);
    vmovups(ptr[param3 + offset], xmm_dst);
    offset += sizeof(float) * 4;
    rest -= 4;
  }
  if (rest >= 2) {
    // vmovq moves only the low 64 bits (two floats) of the xmm register.
    vmovq(xmm_src1, ptr[param1 + offset]);
    vmovq(xmm_src2, ptr[param2 + offset]);
    vaddps(xmm_dst, xmm_src1, xmm_src2);
    vmovq(ptr[param3 + offset], xmm_dst);
    offset += sizeof(float) * 2;
    rest -= 2;
  }
  if (rest > 0) {
    vmovss(xmm_src1, ptr[param1 + offset]);
    vmovss(xmm_src2, ptr[param2 + offset]);
    vaddss(xmm_dst, xmm_src1, xmm_src2);
    vmovss(ptr[param3 + offset], xmm_dst);
  }
  ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
......
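For reference, the emitted instruction stream computes a plain elementwise float add. Below is a minimal standalone C++ sketch of the same tiling (an 8-wide body, then 4-, 2- and 1-wide tails); the function name vadd_scalar_mirror is ours, not part of the patch.

#include <initializer_list>

// Scalar mirror of the JIT kernel's tiling: 8 floats per iteration in the
// main loop, then tails of width 4, 2 and 1 (the vmovups/vmovq/vmovss paths).
void vadd_scalar_mirror(const float* x, const float* y, float* z, int n) {
  int i = 0;
  for (; i + 8 <= n; i += 8) {  // main loop: one 256-bit vaddps per 8 floats
    for (int j = 0; j < 8; ++j) z[i + j] = x[i + j] + y[i + j];
  }
  for (int w : {4, 2, 1}) {  // tail widths, largest first
    if (n - i >= w) {
      for (int j = 0; j < w; ++j) z[i + j] = x[i + j] + y[i + j];
      i += w;
    }
  }
}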
@@ -53,6 +53,30 @@ class VMulJitCode : public JitCode {
  ymm_t ymm_dst = ymm_t(2);
};
class VAddJitCode : public JitCode {
 public:
  DECLARE_JIT_CODE(VAddJitCode);
  explicit VAddJitCode(int d, size_t code_size = 256 * 1024,
                       void* code_ptr = nullptr)
      : JitCode(code_size, code_ptr), num_(d) {}
  static bool init(int d);
  void generate() override;

 private:
  int num_;
  reg64_t param1{abi_param1};
  reg64_t param2{abi_param2};
  reg64_t param3{abi_param3};
  xmm_t xmm_src1 = xmm_t(0);
  xmm_t xmm_src2 = xmm_t(1);
  xmm_t xmm_dst = xmm_t(2);
  ymm_t ymm_src1 = ymm_t(0);
  ymm_t ymm_src2 = ymm_t(1);
  ymm_t ymm_dst = ymm_t(2);
};
} // namespace gen
} // namespace jitkernel
} // namespace math
......
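Note the register aliasing in the declarations above: xmm_t(0) is the low 128 bits of the same physical register as ymm_t(0) (likewise for indices 1 and 2), so the tail paths in generate() reuse the main loop's registers without extra moves.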
@@ -71,7 +71,7 @@ class VMulKernel : public Kernel {
template <typename T>
class VAddKernel : public Kernel {
 public:
-  virtual void Compute(const T *x, const T *y, T *z) const = 0;
+  void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
......
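The change above replaces the virtual Compute with a function pointer that takes the length explicitly: the implementation (JIT, MKL or reference) is chosen once when the kernel is built, and every subsequent call is a plain indirect call with no virtual dispatch and no stored size. A self-contained mock of the pattern, using our own names rather than Paddle's types:

#include <cstdio>

// Mock of the new interface: Compute is a data member bound at setup time.
template <typename T>
struct VAddKernelLike {
  void (*Compute)(const T*, const T*, T*, int) = nullptr;
};

template <typename T>
void add_ref(const T* x, const T* y, T* z, int n) {
  for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
}

int main() {
  VAddKernelLike<float> ker;
  ker.Compute = add_ref<float>;  // dispatch decided once, at construction
  float x[3] = {1, 2, 3}, y[3] = {4, 5, 6}, z[3];
  ker.Compute(x, y, z, 3);  // hot path: plain call through a pointer
  std::printf("%g %g %g\n", z[0], z[1], z[2]);
  return 0;
}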
@@ -39,6 +39,13 @@ void VMulRefer(const T* x, const T* y, T* z, int n) {
  }
}
template <typename T>
void VAddRefer(const T* x, const T* y, T* z, int n) {
  for (int i = 0; i < n; ++i) {
    z[i] = x[i] + y[i];
  }
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -47,22 +54,38 @@ template <>
void VMulMKL<float>(const float* x, const float* y, float* z, int n) {
  platform::dynload::vsMul(n, x, y, z);
}

template <>
void VMulMKL<double>(const double* x, const double* y, double* z, int n) {
  platform::dynload::vdMul(n, x, y, z);
}
template <typename T>
void VAddMKL(const T* x, const T* y, T* z, int n);

template <>
void VAddMKL<float>(const float* x, const float* y, float* z, int n) {
  platform::dynload::vsAdd(n, x, y, z);
}

template <>
void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
  platform::dynload::vdAdd(n, x, y, z);
}
#endif
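vsAdd and vdAdd are MKL VML routines for elementwise single- and double-precision addition, reached here through Paddle's dynload wrappers. For illustration, a direct call (assuming an MKL installation, linked e.g. with -lmkl_rt; the wrapper-free function below is our own):

#include <mkl.h>

// Hypothetical direct use of MKL; the patch itself goes through
// platform::dynload so the symbol is resolved at runtime.
void vadd_mkl_direct(const float* x, const float* y, float* z, int n) {
  vsAdd(n, x, y, z);  // z[i] = x[i] + y[i], vectorized inside MKL
}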
#define DECLARE_STATIC_FUNC                                   \
  static inline std::string name(int d) {                     \
    PADDLE_THROW("DType should be either float or double");   \
  }                                                           \
  static inline bool useJIT(int d) { return false; }          \
  static inline bool useMKL(int d) { return false; }
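DECLARE_STATIC_FUNC gives each kernel implementation pessimistic defaults (the default name() simply throws, and useJIT/useMKL return false); explicit per-type specializations further down then opt individual types back in. A minimal standalone illustration of that idiom, with our own names:

#include <iostream>

// Primary template: defaults that opt out of every fast path.
template <typename T>
struct KernelTraits {
  static bool useJIT(int) { return false; }
  static bool useMKL(int) { return false; }
};

// Explicit specializations of the members opt float back in.
template <>
bool KernelTraits<float>::useJIT(int d) {
  return d >= 8;  // stand-in for gen::VAddJitCode::init(d)
}

template <>
bool KernelTraits<float>::useMKL(int d) {
  return d > 512;  // mirrors the float cutoff used in this patch
}

int main() {
  std::cout << KernelTraits<float>::useJIT(16) << ' '
            << KernelTraits<double>::useJIT(16) << '\n';  // prints: 1 0
  return 0;
}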
/* VMUL JitKernel */
template <typename T>
class VMulKernelImpl : public VMulKernel<T> {
 public:
-  static inline std::string name(int d) {
-    PADDLE_THROW("DType should be either float or double");
-  }
-  static inline bool useJIT(int d) { return false; }
-  static inline bool useMKL(int d) { return false; }
+  DECLARE_STATIC_FUNC;
  explicit VMulKernelImpl(int d) : VMulKernel<T>() {
    if (useJIT(d)) {
      // roughly estimate the size of code
@@ -100,63 +123,51 @@ bool VMulKernelImpl<double>::useMKL(int d) {
  return true;
}
-REGISTER_JITKERNEL(vmul, VMulKernel);
-/* VADD JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+/* VAdd JitKernel */
+template <typename T>
class VAddKernelImpl : public VAddKernel<T> {
 public:
-  explicit VAddKernelImpl(int d) : VAddKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, const T* y, T* z) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      z[i] = x[i] + y[i];
+  DECLARE_STATIC_FUNC;
+  explicit VAddKernelImpl(int d) : VAddKernel<T>() {
+    if (useJIT(d)) {
+      // roughly estimate the size of the generated code
+      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
+      jitcode_.reset(new gen::VAddJitCode(d, sz > 4096 ? sz : 4096));
+      this->Compute =
+          jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
+      return;
+    }
+#ifdef PADDLE_WITH_MKLML
+    if (useMKL(d)) {
+      this->Compute = VAddMKL<T>;
+      return;
+    }
+#endif
+    this->Compute = VAddRefer<T>;
+  }

+ private:
+  std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
+};
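The sz expression is a rough upper bound on the generated code: on our reading, about 96 bytes of prologue/epilogue plus 4 instructions of at most 8 bytes each per 8-float block, floored at one 4 KiB page. The constants are not documented in the patch, so treat the breakdown as an interpretation; the arithmetic itself is easy to check:

#include <cstddef>
#include <cstdio>

constexpr int kAvxFloatBlock = 8;  // AVX_FLOAT_BLOCK in this codebase

std::size_t EstimateCodeSize(int d) {
  std::size_t sz = 96 + d / kAvxFloatBlock * 4 * 8;
  return sz > 4096 ? sz : 4096;
}

int main() {
  std::printf("%zu\n", EstimateCodeSize(256));   // 96 + 1024 = 1120 -> 4096
  std::printf("%zu\n", EstimateCodeSize(2048));  // 96 + 8192 = 8288
  return 0;
}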
-#ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block)                           \
-  template <>                                           \
-  void VAddKernelImpl<float, isa, block>::Compute(      \
-      const float* x, const float* y, float* z) const { \
-    platform::dynload::vsAdd(this->num_, x, y, z);      \
-  }
+template <>
+bool VAddKernelImpl<float>::useJIT(int d) {
+  return gen::VAddJitCode::init(d);
+}
-#define MKL_DOUBLE(isa, block)                             \
-  template <>                                              \
-  void VAddKernelImpl<double, isa, block>::Compute(        \
-      const double* x, const double* y, double* z) const { \
-    platform::dynload::vdAdd(this->num_, x, y, z);         \
-  }
+template <>
+bool VAddKernelImpl<float>::useMKL(int d) {
+  return d > 512;  // shorter float vectors take the JIT/reference paths
+}
-FOR_EACH_ISA(MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
-#endif
+template <>
+bool VAddKernelImpl<double>::useMKL(int d) {
+  return true;  // no double-precision JIT path, so always prefer MKL
+}
-#define INTRI8_FLOAT(isa)                               \
-  template <>                                           \
-  void VAddKernelImpl<float, isa, kEQ8>::Compute(       \
-      const float* x, const float* y, float* z) const { \
-    __m256 tmpx, tmpy;                                  \
-    tmpx = _mm256_loadu_ps(x);                          \
-    tmpy = _mm256_loadu_ps(y);                          \
-    tmpx = _mm256_add_ps(tmpx, tmpy);                   \
-    _mm256_storeu_ps(z, tmpx);                          \
-  }
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-#endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-#endif
-// TODO(TJ): eq16 test and complete avx512
+#undef DECLARE_STATIC_FUNC
-#undef INTRI8_FLOAT
-#undef MKL_FLOAT
-#undef MKL_DOUBLE
+REGISTER_JITKERNEL(vmul, VMulKernel);
+REGISTER_JITKERNEL(vadd, VAddKernel);
/* VSCAL JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
@@ -480,7 +491,6 @@ INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
-REGISTER_JITKERNEL_DEPRECATED(vadd, VAddKernel);
REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
......
@@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cand_d_->Compute(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
-    vadd_d_->Compute(gates + d_, gates + d2_, ct);
+    vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_);
@@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
/* get fgated and igated*/
vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
-    vadd_d2_->Compute(checked, gates + d_, gates + d_);
+    vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->Compute(gates + d_, gates + d_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->Compute(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
-    vadd_d_->Compute(gates + d_, gates + d2_, ct);
+    vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
-    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
+    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->Compute(gates + d3_, gates + d3_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_);
@@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
-    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
+    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_);
act_cell_d_->Compute(ct, gates + d2_);
......
@@ -371,7 +371,7 @@ void lstm_ctht_better(
vtanh_d->Compute(gates, gates);
vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
-    vadd_d->Compute(gates + d, gates + d2, ct);
+    vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */
vtanh_d->Compute(ct, gates + d2);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
@@ -695,7 +695,7 @@ TEST(JitKernel, vadd) {
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
-    ker->Compute(x_data, y_data, ztgt_data);
+    ker->Compute(x_data, y_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
@@ -723,8 +723,8 @@ void vaddrelu_better(
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
-    const float* x, const float* y, float* z) {
-  vadd->Compute(x, y, z);
+    const float* x, const float* y, float* z, int d) {
+  vadd->Compute(x, y, z, d);
vrelu->Compute(z, z);
}
@@ -752,7 +752,7 @@ TEST(JitKernel, vaddrelu) {
auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
-    vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data);
+    vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d);
}
auto tmkle = GetCurrentUS();
auto ttgts = GetCurrentUS();
......