提交 5e64244f 编写于 作者: T tensor-tang

add vaddbias jitcode

test=develop
上级 5f7956ae
......@@ -31,16 +31,26 @@ using Label = Xbyak::Label;
typedef enum { mul = 0, add } operand_type;
// function: vec = Operand(vec(scalar), vec(scalar)) (maybe with relu)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
const char* name() const override {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::mul) {
base += "_Mul";
} else if (type_ == operand_type::add) {
base += "_Add";
}
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
......
......@@ -83,13 +83,15 @@ class VAddReluKernel : public Kernel {
template <typename T>
class VScalKernel : public Kernel {
public:
// y = a.*x
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VAddBiasKernel : public Kernel {
public:
virtual void Compute(const T a, const T *x, T *y) const = 0;
// y = a.+x
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
......
......@@ -60,6 +60,13 @@ void VScalRefer(const T* a, const T* x, T* y, int n) {
}
}
template <typename T>
void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] + x[i];
}
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
......@@ -300,62 +307,46 @@ bool VScalKernelImpl<double>::useMKL(int d) {
}
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
/* VAddBias JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
template <typename T>
class VAddBiasKernelImpl : public VAddBiasKernel<T> {
public:
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
void Compute(const T a, const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = x[i] + a;
DECLARE_STATIC_FUNC;
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VAddBiasKernelImpl<float, isa, kEQ8>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \
_mm256_storeu_ps(y, tmp); \
}
#endif
#define INTRI16_FLOAT(isa) \
template <> \
void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \
tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
this->Compute = VAddBiasRefer<T>;
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
private:
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddBiasKernelImpl<float>::useJIT(int d) {
return gen::VXXJitCode::init(d, 1);
}
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
/* VRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
......@@ -466,7 +457,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void Compute(const T* x, T* y) const override {}
};
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
......
......@@ -409,11 +409,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
}
void Compute(const T* x, T* y) const override {
const T a = static_cast<T>(2);
const T a = static_cast<T>(2), b = static_cast<T>(-1);
vscal_->Compute(&a, x, y, this->num_);
vsigmoid_->Compute(y, y);
vscal_->Compute(&a, y, y, this->num_);
vaddbias_->Compute(static_cast<T>(-1), y, y);
vaddbias_->Compute(&b, y, y, this->num_);
}
private:
......@@ -473,11 +473,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
const float a = 2.f; \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(-1.f, y, y); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
......@@ -504,11 +504,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
} \
x += this->end_; \
y += this->end_; \
const float a = 2.f; \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(-1.f, y, y); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#ifndef __WIN32
......
......@@ -128,7 +128,7 @@ TEST(JitKernel, vaddbias) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, x_data, ztgt_data);
ker->Compute(&a, x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
......@@ -281,11 +281,11 @@ void vtanh_better(
const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
vaddbias,
const int n, const float* x, float* y) {
const float tmp1 = 2.f;
vscal->Compute(&tmp1, x, y, n);
const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n);
vsigmoid->Compute(y, y);
vscal->Compute(&tmp1, y, y, n);
vaddbias->Compute(-1.f, y, y);
vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n);
}
TEST(JitKernel, vtanh) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册