提交 03e11f3f 编写于 作者: T tensor-tang

add vscal jitcode

上级 5b7a9dd7
......@@ -96,6 +96,41 @@ void VVVJitCode::generate() {
}
ret();
}
bool VScalJitCode::init(int d) { return MayIUse(avx); }
void VScalJitCode::generate() {
int offset = 0;
vbroadcastss(ymm_src1, ptr[param1]);
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src2, ptr[param2 + offset]);
vmulps(ymm_dst, ymm_src1, ymm_src2);
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK;
}
int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src2, ptr[param2 + offset]);
vmulps(xmm_dst, xmm_src1, xmm_src2);
vmovups(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
vmulps(xmm_dst, xmm_src1, xmm_src2);
vmovq(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
vmovss(xmm_src2, ptr[param2 + offset]);
vmulss(xmm_dst, xmm_src1, xmm_src2);
vmovss(ptr[param3 + offset], xmm_dst);
}
ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -29,9 +29,9 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
// function: vec = Operand(vec, vec) (maybe with relu)
typedef enum { mul = 0, add } operand_type;
// function: vec = Operand(vec, vec) (maybe with relu)
class VVVJitCode : public JitCode {
public:
const char* name() const override {
......@@ -41,7 +41,7 @@ class VVVJitCode : public JitCode {
} else if (type_ == operand_type::add) {
base += "_Add";
}
base += (with_relu_ ? "_relu" : "");
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
explicit VVVJitCode(int d, operand_type type, bool with_relu,
......@@ -72,6 +72,32 @@ class VVVJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(2);
};
class VScalJitCode : public JitCode {
public:
DECLARE_JIT_CODE(VScalJitCode);
explicit VScalJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(1);
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_zero = ymm_t(2);
};
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -83,8 +83,7 @@ class VAddReluKernel : public Kernel {
template <typename T>
class VScalKernel : public Kernel {
public:
virtual void Compute(const T a, const T *x, T *y) const = 0;
virtual void Compute(const T a, T *x) const = 0;
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
......
......@@ -57,6 +57,13 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
}
}
template <typename T>
void VScalRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] * x[i];
}
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
......@@ -83,6 +90,28 @@ template <>
void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
template <typename T>
void VScalMKL(const T* a, const T* x, T* y, int n);
template <>
void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
VScalRefer<float>(a, x, y, n);
}
}
template <>
void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
VScalRefer<double>(a, x, y, n);
}
}
#endif
#define DECLARE_STATIC_FUNC \
......@@ -226,87 +255,60 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
}
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
/* VSCAL JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
/* VScal JitKernel */
template <typename T>
class VScalKernelImpl : public VScalKernel<T> {
public:
explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
void Compute(const T a, const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = a * x[i];
}
}
void Compute(const T a, T* x) const override {
for (int i = 0; i < this->num_; ++i) {
x[i] = a * x[i];
DECLARE_STATIC_FUNC;
explicit VScalKernelImpl(int d) : VScalKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
}
};
#endif
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
const { \
platform::dynload::cblas_sscal(this->num_, a, x, 1); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
const { \
platform::dynload::cblas_dscal(this->num_, a, x, 1); \
}
FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
if (useMKL(d)) {
this->Compute = VScalMKL<T>;
return;
}
#endif
#define INTRI8_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI8_INPLACE_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(x, tmp); \
this->Compute = VScalRefer<T>;
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI8_INPLACE_FLOAT(jit::avx);
private:
std::unique_ptr<gen::VScalJitCode> jitcode_{nullptr};
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI8_INPLACE_FLOAT(jit::avx2);
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VScalKernelImpl<float>::useJIT(int d) {
return gen::VScalJitCode::init(d);
}
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
INTRI8_INPLACE_FLOAT(jit::avx512f);
#ifdef PADDLE_WITH_MKLML
template <>
bool VScalKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VScalKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI8_INPLACE_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
/* VAddBias JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
......@@ -467,7 +469,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void Compute(const T* x, T* y) const override {}
};
REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
......
......@@ -409,9 +409,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
}
void Compute(const T* x, T* y) const override {
vscal_->Compute(static_cast<T>(2), x, y);
const T a = static_cast<T>(2);
vscal_->Compute(&a, x, y, this->num_);
vsigmoid_->Compute(y, y);
vscal_->Compute(static_cast<T>(2), y);
vscal_->Compute(&a, y, y, this->num_);
vaddbias_->Compute(static_cast<T>(-1), y, y);
}
......@@ -472,9 +473,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(2.f, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(-1.f, y, y); \
}
......@@ -502,9 +504,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
} \
x += this->end_; \
y += this->end_; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(2.f, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(-1.f, y, y); \
}
......
......@@ -281,9 +281,10 @@ void vtanh_better(
const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
vaddbias,
const int n, const float* x, float* y) {
vscal->Compute(2.f, x, y);
const float tmp1 = 2.f;
vscal->Compute(&tmp1, x, y, n);
vsigmoid->Compute(y, y);
vscal->Compute(2.f, y);
vscal->Compute(&tmp1, y, y, n);
vaddbias->Compute(-1.f, y, y);
}
......@@ -531,12 +532,12 @@ TEST(JitKernel, vscal) {
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, x_data, ztgt_data);
ker->Compute(&a, x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
auto ttgts1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, y_data);
ker->Compute(&a, y_data, y_data, d);
}
auto ttgte1 = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册