提交 3d950a81 编写于 作者: T tensor-tang

combine jitcode of vscal

上级 03e11f3f
...@@ -24,21 +24,30 @@ namespace gen { ...@@ -24,21 +24,30 @@ namespace gen {
using namespace platform::jit; // NOLINT using namespace platform::jit; // NOLINT
bool VVVJitCode::init(int d) { bool VXXJitCode::init(int d, int scalar_index) {
// It's not necessary to use avx512 since it would slow down the frequency // It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound. // and this kernel is not compute bound.
return MayIUse(avx); return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2;
} }
void VVVJitCode::generate() { void VXXJitCode::generate() {
// do not need push stack, and do not need save avx512reg if do not use avx512 // do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0; int offset = 0;
if (with_relu_) { if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero); vxorps(ymm_zero, ymm_zero, ymm_zero);
} }
if (scalar_index_ == 1) {
vbroadcastss(ymm_src1, ptr[param1]);
} else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]);
}
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src1, ptr[param1 + offset]); if (scalar_index_ != 1) {
vmovups(ymm_src2, ptr[param2 + offset]); vmovups(ymm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) { if (type_ == operand_type::mul) {
vmulps(ymm_dst, ymm_src1, ymm_src2); vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) { } else if (type_ == operand_type::add) {
...@@ -52,8 +61,12 @@ void VVVJitCode::generate() { ...@@ -52,8 +61,12 @@ void VVVJitCode::generate() {
} }
int rest = num_ % AVX_FLOAT_BLOCK; int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
vmovups(xmm_src1, ptr[param1 + offset]); if (scalar_index_ != 1) {
vmovups(xmm_src2, ptr[param2 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) { if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2); vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) { } else if (type_ == operand_type::add) {
...@@ -67,8 +80,12 @@ void VVVJitCode::generate() { ...@@ -67,8 +80,12 @@ void VVVJitCode::generate() {
rest -= 4; rest -= 4;
} }
if (rest >= 2) { if (rest >= 2) {
vmovq(xmm_src1, ptr[param1 + offset]); if (scalar_index_ != 1) {
vmovq(xmm_src2, ptr[param2 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) { if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2); vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) { } else if (type_ == operand_type::add) {
...@@ -82,8 +99,12 @@ void VVVJitCode::generate() { ...@@ -82,8 +99,12 @@ void VVVJitCode::generate() {
rest -= 2; rest -= 2;
} }
if (rest > 0) { if (rest > 0) {
vmovss(xmm_src1, ptr[param1 + offset]); if (scalar_index_ != 1) {
vmovss(xmm_src2, ptr[param2 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) { if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2); vmulss(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) { } else if (type_ == operand_type::add) {
...@@ -97,40 +118,6 @@ void VVVJitCode::generate() { ...@@ -97,40 +118,6 @@ void VVVJitCode::generate() {
ret(); ret();
} }
bool VScalJitCode::init(int d) { return MayIUse(avx); }
void VScalJitCode::generate() {
int offset = 0;
vbroadcastss(ymm_src1, ptr[param1]);
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src2, ptr[param2 + offset]);
vmulps(ymm_dst, ymm_src1, ymm_src2);
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK;
}
int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src2, ptr[param2 + offset]);
vmulps(xmm_dst, xmm_src1, xmm_src2);
vmovups(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
vmulps(xmm_dst, xmm_src1, xmm_src2);
vmovq(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
vmovss(xmm_src2, ptr[param2 + offset]);
vmulss(xmm_dst, xmm_src1, xmm_src2);
vmovss(ptr[param3 + offset], xmm_dst);
}
ret();
}
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -31,11 +31,11 @@ using Label = Xbyak::Label; ...@@ -31,11 +31,11 @@ using Label = Xbyak::Label;
typedef enum { mul = 0, add } operand_type; typedef enum { mul = 0, add } operand_type;
// function: vec = Operand(vec, vec) (maybe with relu) // function: vec = Operand(vec(scalar), vec(scalar)) (maybe with relu)
class VVVJitCode : public JitCode { class VXXJitCode : public JitCode {
public: public:
const char* name() const override { const char* name() const override {
std::string base = "VVVJitCode"; std::string base = "VXXJitCode";
if (type_ == operand_type::mul) { if (type_ == operand_type::mul) {
base += "_Mul"; base += "_Mul";
} else if (type_ == operand_type::add) { } else if (type_ == operand_type::add) {
...@@ -44,18 +44,21 @@ class VVVJitCode : public JitCode { ...@@ -44,18 +44,21 @@ class VVVJitCode : public JitCode {
base += (with_relu_ ? "_Relu" : ""); base += (with_relu_ ? "_Relu" : "");
return base.c_str(); return base.c_str();
} }
explicit VVVJitCode(int d, operand_type type, bool with_relu, explicit VXXJitCode(int d, operand_type type, int scalar_index,
size_t code_size = 256 * 1024, void* code_ptr = nullptr) bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), : JitCode(code_size, code_ptr),
num_(d), num_(d),
type_(type), type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {} with_relu_(with_relu) {}
static bool init(int d); static bool init(int d, int scalar_index = 0);
void generate() override; void generate() override;
private: private:
int num_; int num_;
operand_type type_; operand_type type_;
int scalar_index_;
bool with_relu_; bool with_relu_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
reg64_t param2{abi_param2}; reg64_t param2{abi_param2};
...@@ -63,39 +66,13 @@ class VVVJitCode : public JitCode { ...@@ -63,39 +66,13 @@ class VVVJitCode : public JitCode {
xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1); xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(1); xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(2); xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1); ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(1); ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(2); ymm_t ymm_zero = ymm_t(3);
};
class VScalJitCode : public JitCode {
public:
DECLARE_JIT_CODE(VScalJitCode);
explicit VScalJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(1);
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_zero = ymm_t(2);
}; };
} // namespace gen } // namespace gen
......
...@@ -131,7 +131,7 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -131,7 +131,7 @@ class VMulKernelImpl : public VMulKernel<T> {
if (useJIT(d)) { if (useJIT(d)) {
// roughly estimate the size of code // roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
...@@ -150,14 +150,14 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -150,14 +150,14 @@ class VMulKernelImpl : public VMulKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VMulKernelImpl<float>::useJIT(int d) { bool VMulKernelImpl<float>::useJIT(int d) {
return gen::VVVJitCode::init(d); return gen::VXXJitCode::init(d);
} }
#endif #endif
...@@ -182,7 +182,7 @@ class VAddKernelImpl : public VAddKernel<T> { ...@@ -182,7 +182,7 @@ class VAddKernelImpl : public VAddKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
...@@ -200,14 +200,14 @@ class VAddKernelImpl : public VAddKernel<T> { ...@@ -200,14 +200,14 @@ class VAddKernelImpl : public VAddKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VAddKernelImpl<float>::useJIT(int d) { bool VAddKernelImpl<float>::useJIT(int d) {
return gen::VVVJitCode::init(d); return gen::VXXJitCode::init(d);
} }
#endif #endif
...@@ -232,7 +232,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> { ...@@ -232,7 +232,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
...@@ -244,14 +244,14 @@ class VAddReluKernelImpl : public VAddReluKernel<T> { ...@@ -244,14 +244,14 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VAddReluKernelImpl<float>::useJIT(int d) { bool VAddReluKernelImpl<float>::useJIT(int d) {
return gen::VVVJitCode::init(d); return gen::VXXJitCode::init(d);
} }
#endif #endif
...@@ -264,7 +264,8 @@ class VScalKernelImpl : public VScalKernel<T> { ...@@ -264,7 +264,8 @@ class VScalKernelImpl : public VScalKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return; return;
...@@ -281,14 +282,14 @@ class VScalKernelImpl : public VScalKernel<T> { ...@@ -281,14 +282,14 @@ class VScalKernelImpl : public VScalKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VScalJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VScalKernelImpl<float>::useJIT(int d) { bool VScalKernelImpl<float>::useJIT(int d) {
return gen::VScalJitCode::init(d); return gen::VXXJitCode::init(d, 1);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册