提交 e0591dee 编写于 作者: T tensor-tang

enhance seqpool jitcode

上级 92201d39
...@@ -194,8 +194,8 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> ...@@ -194,8 +194,8 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() { void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) { for (auto type : pool_types) {
for (int h : TestSizes()) { for (int w : TestSizes()) {
for (int w : TestSizes()) { for (int h : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type); const jit::seq_pool_attr_t attr(h, w, type);
std::vector<T> x(h * w), y(w); std::vector<T> x(h * w), y(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f); RandomVec<T>(h * w, x.data(), -2.f, 2.f);
......
...@@ -35,7 +35,6 @@ void SeqPoolJitCode::genCode() { ...@@ -35,7 +35,6 @@ void SeqPoolJitCode::genCode() {
mov(reg32_scalar, scalar); mov(reg32_scalar, scalar);
} }
// TODO(TJ): make height load from params
const int group_len = max_num_regs * block * sizeof(float); const int group_len = max_num_regs * block * sizeof(float);
for (int g = 0; g < num_groups; ++g) { for (int g = 0; g < num_groups; ++g) {
pool_height<ymm_t>(g * group_len, block, max_num_regs); pool_height<ymm_t>(g * group_len, block, max_num_regs);
...@@ -44,59 +43,9 @@ void SeqPoolJitCode::genCode() { ...@@ -44,59 +43,9 @@ void SeqPoolJitCode::genCode() {
pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs); pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
} }
// rest part // part of rest_w * height
const int rest = w_ % block; const int rest = w_ % block;
const bool has_block4 = rest / 4 > 0; pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
const int w_offset = num_block * YMM_FLOAT_BLOCK * sizeof(float);
for (int h = 0; h < h_; ++h) {
int offset = h * w_ * sizeof(float) + w_offset;
const int shift_regs = (h == 0) ? 0 : max_num_regs;
int reg_idx = 0;
if (has_block4) {
vmovups(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
reg_idx++;
}
rest_num_regs = reg_idx;
if (h > 0) {
for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
}
}
}
// save right now
int offset = w_offset;
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar);
for (int i = 0; i < rest_num_regs; ++i) {
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1));
}
}
int reg_idx = 0;
if (has_block4) {
vmovups(ptr[param2 + offset], xmm_t(reg_idx));
offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(ptr[param2 + offset], xmm_t(reg_idx));
offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(ptr[param2 + offset], xmm_t(reg_idx));
}
ret(); ret();
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -45,8 +46,6 @@ class SeqPoolJitCode : public JitCode { ...@@ -45,8 +46,6 @@ class SeqPoolJitCode : public JitCode {
base += "_Sqrt"; base += "_Sqrt";
} }
base += ("_W" + std::to_string(w_)); base += ("_W" + std::to_string(w_));
// TODO(TJ): make h load from params
base += ("_H" + std::to_string(h_));
return base.c_str(); return base.c_str();
} }
void genCode() override; void genCode() override;
...@@ -54,25 +53,36 @@ class SeqPoolJitCode : public JitCode { ...@@ -54,25 +53,36 @@ class SeqPoolJitCode : public JitCode {
protected: protected:
template <typename JMM> template <typename JMM>
void pool_height(int w_offset, int block, int max_num_regs) { void pool_height(int w_offset, int block, int max_num_regs) {
for (int h = 0; h < h_; ++h) { int offset = w_offset;
int offset = h * w_ * sizeof(float) + w_offset; for (int i = 0; i < max_num_regs; ++i) {
const int shift_regs = (h == 0) ? 0 : max_num_regs; vmovups(JMM(i), ptr[param1 + offset]);
for (int i = 0; i < max_num_regs; ++i) { offset += sizeof(float) * block;
vmovups(JMM(i + shift_regs), ptr[param1 + offset]); }
offset += sizeof(float) * block; if (h_ > 1) {
} Label l_next_h;
if (h > 0) { mov(reg_h, 1);
// sum anyway mov(reg_tmp, param1);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
mov(reg_ptr_src_i, reg_tmp);
for (int i = 0; i < max_num_regs; ++i) { for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
// sum anyway
vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
add(reg_ptr_src_i, sizeof(float) * block);
} }
inc(reg_h);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h, h_);
jl(l_next_h, T_NEAR);
} }
} }
// save right now // save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vbroadcastss(JMM(max_num_regs), reg32_scalar); vbroadcastss(JMM(max_num_regs), reg32_scalar);
} }
int offset = w_offset; offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) { for (int i = 0; i < max_num_regs; ++i) {
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vmulps(JMM(i), JMM(i), JMM(max_num_regs)); vmulps(JMM(i), JMM(i), JMM(max_num_regs));
...@@ -82,6 +92,102 @@ class SeqPoolJitCode : public JitCode { ...@@ -82,6 +92,102 @@ class SeqPoolJitCode : public JitCode {
} }
} }
void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) {
const int rest_used_num_regs = load_rest(rest, w_offset, 0);
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
if (h_ > 1) {
Label l_next_h;
mov(reg_h, 1);
mov(reg_tmp, param1);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
// int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset,
// max_num_regs);
int reg_idx = 0;
mov(reg_ptr_src_i, reg_tmp);
if (has_block4) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 4);
reg_idx++;
}
if (has_block2) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 2);
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
reg_idx++;
}
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
"All heights should use same regs");
for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
}
inc(reg_h);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h, h_);
jl(l_next_h, T_NEAR);
}
}
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar);
for (int i = 0; i < rest_used_num_regs; ++i) {
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1));
}
}
save_rest(rest, w_offset);
}
// return the number of used regs, use start from reg 0
int load_rest(int rest, int w_offset, const int num_shift_regs,
const int reg_start = 0) {
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start;
if (has_block4) {
vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
w_offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
w_offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
reg_idx++;
}
return reg_idx;
}
// use reg start from 0
void save_rest(int rest, int w_offset, int reg_start = 0) {
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start;
if (has_block4) {
vmovups(ptr[param2 + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(ptr[param2 + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(ptr[param2 + w_offset], xmm_t(reg_idx));
}
}
private: private:
int h_; int h_;
int w_; int w_;
...@@ -90,6 +196,10 @@ class SeqPoolJitCode : public JitCode { ...@@ -90,6 +196,10 @@ class SeqPoolJitCode : public JitCode {
reg64_t param2{abi_param2}; reg64_t param2{abi_param2};
reg64_t param3{abi_param3}; reg64_t param3{abi_param3};
reg32_t reg32_scalar{r8d}; reg32_t reg32_scalar{r8d};
reg64_t reg_h{r9};
reg64_t reg_ptr_src_i{r10};
reg64_t reg_tmp{r11};
}; };
} // namespace gen } // namespace gen
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册