提交 382307b9 编写于 作者: T tensor-tang

refine code

test=develop
上级 25e070ec
......@@ -24,51 +24,14 @@ namespace gen {
using namespace platform::jit; // NOLINT
bool VMulJitCode::init(int d) {
bool VVVJitCode::init(int d) {
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return MayIUse(avx);
}
void VMulJitCode::generate() {
void VVVJitCode::generate() {
// do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0;
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src1, ptr[param1 + offset]);
vmovups(ymm_src2, ptr[param2 + offset]);
vmulps(ymm_dst, ymm_src1, ymm_src2);
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK;
}
int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src1, ptr[param1 + offset]);
vmovups(xmm_src2, ptr[param2 + offset]);
vmulps(xmm_dst, xmm_src1, xmm_src2);
vmovups(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovq(xmm_src1, ptr[param1 + offset]);
vmovq(xmm_src2, ptr[param2 + offset]);
vmulps(xmm_dst, xmm_src1, xmm_src2);
vmovq(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
vmovss(xmm_src1, ptr[param1 + offset]);
vmovss(xmm_src2, ptr[param2 + offset]);
vmulss(xmm_dst, xmm_src1, xmm_src2);
vmovss(ptr[param3 + offset], xmm_dst);
}
ret();
}
bool VAddJitCode::init(int d) { return MayIUse(avx); }
void VAddJitCode::generate() {
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
......@@ -76,7 +39,11 @@ void VAddJitCode::generate() {
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src1, ptr[param1 + offset]);
vmovups(ymm_src2, ptr[param2 + offset]);
vaddps(ymm_dst, ymm_src1, ymm_src2);
if (type_ == operand_type::mul) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
......@@ -87,7 +54,11 @@ void VAddJitCode::generate() {
if (rest >= 4) {
vmovups(xmm_src1, ptr[param1 + offset]);
vmovups(xmm_src2, ptr[param2 + offset]);
vaddps(xmm_dst, xmm_src1, xmm_src2);
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
......@@ -98,7 +69,11 @@ void VAddJitCode::generate() {
if (rest >= 2) {
vmovq(xmm_src1, ptr[param1 + offset]);
vmovq(xmm_src2, ptr[param2 + offset]);
vaddps(xmm_dst, xmm_src1, xmm_src2);
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
......@@ -109,7 +84,11 @@ void VAddJitCode::generate() {
if (rest > 0) {
vmovss(xmm_src1, ptr[param1 + offset]);
vmovss(xmm_src2, ptr[param2 + offset]);
vaddss(xmm_dst, xmm_src1, xmm_src2);
if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddss(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
......
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
namespace paddle {
namespace operators {
namespace math {
......@@ -29,41 +29,33 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
class VMulJitCode : public JitCode {
public:
DECLARE_JIT_CODE(VMulJitCode);
explicit VMulJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(1);
};
// function: vec = Operand(vec, vec) (maybe with relu)
typedef enum { mul = 0, add } operand_type;
class VAddJitCode : public JitCode {
class VVVJitCode : public JitCode {
public:
DECLARE_JIT_CODE(VAddJitCode);
explicit VAddJitCode(int d, bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d), with_relu_(with_relu) {}
const char* name() const override {
std::string base = "VVVJitCode";
if (type_ == operand_type::mul) {
base += "_Mul";
} else if (type_ == operand_type::add) {
base += "_Add";
}
base += (with_relu_ ? "_relu" : "");
return base.c_str();
}
explicit VVVJitCode(int d, operand_type type, bool with_relu,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
with_relu_(with_relu) {}
static bool init(int d);
void generate() override;
private:
int num_;
operand_type type_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
......
......@@ -102,7 +102,8 @@ class VMulKernelImpl : public VMulKernel<T> {
if (useJIT(d)) {
// roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VMulJitCode(d, sz > 4096 ? sz : 4096));
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
......@@ -120,14 +121,14 @@ class VMulKernelImpl : public VMulKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VMulKernelImpl<float>::useJIT(int d) {
return gen::VMulJitCode::init(d);
return gen::VVVJitCode::init(d);
}
#endif
......@@ -149,13 +150,16 @@ class VAddKernelImpl : public VAddKernel<T> {
public:
DECLARE_STATIC_FUNC;
explicit VAddKernelImpl(int d) : VAddKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VAddJitCode(d, false, sz > 4096 ? sz : 4096));
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VAddMKL<T>;
......@@ -166,14 +170,17 @@ class VAddKernelImpl : public VAddKernel<T> {
}
private:
std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddKernelImpl<float>::useJIT(int d) {
return gen::VAddJitCode::init(d);
return gen::VVVJitCode::init(d);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VAddKernelImpl<float>::useMKL(int d) {
return d > 512;
......@@ -183,6 +190,7 @@ template <>
bool VAddKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VAddRelu JitKernel */
template <typename T>
......@@ -190,24 +198,29 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
public:
DECLARE_STATIC_FUNC;
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VAddJitCode(d, true, sz > 4096 ? sz : 4096));
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
this->Compute = VAddReluRefer<T>;
}
private:
std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddReluKernelImpl<float>::useJIT(int d) {
return gen::VAddJitCode::init(d);
return gen::VVVJitCode::init(d);
}
#endif
#undef DECLARE_STATIC_FUNC
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册