From e0591deebc02202c4ae8bfc95f31be606b8192b8 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 4 Jan 2019 14:40:43 +0000 Subject: [PATCH] enhance seqpool jitcode --- paddle/fluid/operators/jit/benchmark.cc | 4 +- paddle/fluid/operators/jit/gen/seqpool.cc | 55 +-------- paddle/fluid/operators/jit/gen/seqpool.h | 134 ++++++++++++++++++++-- 3 files changed, 126 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index f64e43389a5..37a552fb6da 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -194,8 +194,8 @@ template void BenchSeqPoolKernel() { std::vector pool_types = {jit::SeqPoolType::kSum}; for (auto type : pool_types) { - for (int h : TestSizes()) { - for (int w : TestSizes()) { + for (int w : TestSizes()) { + for (int h : TestSizes()) { const jit::seq_pool_attr_t attr(h, w, type); std::vector x(h * w), y(w); RandomVec(h * w, x.data(), -2.f, 2.f); diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc index ce6801b0307..fd83f834366 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.cc +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -35,7 +35,6 @@ void SeqPoolJitCode::genCode() { mov(reg32_scalar, scalar); } - // TODO(TJ): make height load from params const int group_len = max_num_regs * block * sizeof(float); for (int g = 0; g < num_groups; ++g) { pool_height(g * group_len, block, max_num_regs); @@ -44,59 +43,9 @@ void SeqPoolJitCode::genCode() { pool_height(num_groups * group_len, block, rest_num_regs); } - // rest part + // part of rest_w * height const int rest = w_ % block; - const bool has_block4 = rest / 4 > 0; - const bool has_block2 = (rest % 4) / 2 > 0; - const bool has_block1 = (rest % 2) == 1; - const int w_offset = num_block * YMM_FLOAT_BLOCK * sizeof(float); - for (int h = 0; h < h_; ++h) { - int offset = h * w_ * sizeof(float) + w_offset; - const int shift_regs = (h == 0) ? 0 : max_num_regs; - int reg_idx = 0; - if (has_block4) { - vmovups(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); - offset += sizeof(float) * 4; - reg_idx++; - } - if (has_block2) { - vmovq(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); - offset += sizeof(float) * 2; - reg_idx++; - } - if (has_block1) { - vmovss(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); - reg_idx++; - } - rest_num_regs = reg_idx; - if (h > 0) { - for (int i = 0; i < reg_idx; ++i) { - vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); - } - } - } - // save right now - int offset = w_offset; - if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); - for (int i = 0; i < rest_num_regs; ++i) { - vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); - } - } - int reg_idx = 0; - if (has_block4) { - vmovups(ptr[param2 + offset], xmm_t(reg_idx)); - offset += sizeof(float) * 4; - reg_idx++; - } - if (has_block2) { - vmovq(ptr[param2 + offset], xmm_t(reg_idx)); - offset += sizeof(float) * 2; - reg_idx++; - } - if (has_block1) { - vmovss(ptr[param2 + offset], xmm_t(reg_idx)); - } + pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs); ret(); } diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h index eb2d1913826..48288d8c2ae 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.h +++ b/paddle/fluid/operators/jit/gen/seqpool.h @@ -17,6 +17,7 @@ #include #include "glog/logging.h" #include "paddle/fluid/operators/jit/gen/jitcode.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { @@ -45,8 +46,6 @@ class SeqPoolJitCode : public JitCode { base += "_Sqrt"; } base += ("_W" + std::to_string(w_)); - // TODO(TJ): make h load from params - base += ("_H" + std::to_string(h_)); return base.c_str(); } void genCode() override; @@ -54,25 +53,36 @@ class SeqPoolJitCode : public JitCode { protected: template void pool_height(int w_offset, int block, int max_num_regs) { - for (int h = 0; h < h_; ++h) { - int offset = h * w_ * sizeof(float) + w_offset; - const int shift_regs = (h == 0) ? 0 : max_num_regs; - for (int i = 0; i < max_num_regs; ++i) { - vmovups(JMM(i + shift_regs), ptr[param1 + offset]); - offset += sizeof(float) * block; - } - if (h > 0) { - // sum anyway + int offset = w_offset; + for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i), ptr[param1 + offset]); + offset += sizeof(float) * block; + } + if (h_ > 1) { + Label l_next_h; + mov(reg_h, 1); + mov(reg_tmp, param1); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + mov(reg_ptr_src_i, reg_tmp); for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]); + // sum anyway vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); + add(reg_ptr_src_i, sizeof(float) * block); } + inc(reg_h); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h, h_); + jl(l_next_h, T_NEAR); } } // save right now if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { vbroadcastss(JMM(max_num_regs), reg32_scalar); } - int offset = w_offset; + offset = w_offset; for (int i = 0; i < max_num_regs; ++i) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { vmulps(JMM(i), JMM(i), JMM(max_num_regs)); @@ -82,6 +92,102 @@ class SeqPoolJitCode : public JitCode { } } + void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) { + const int rest_used_num_regs = load_rest(rest, w_offset, 0); + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + if (h_ > 1) { + Label l_next_h; + mov(reg_h, 1); + mov(reg_tmp, param1); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + // int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset, + // max_num_regs); + int reg_idx = 0; + mov(reg_ptr_src_i, reg_tmp); + if (has_block4) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 4); + reg_idx++; + } + if (has_block2) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 2); + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + reg_idx++; + } + PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs, + "All heights should use same regs"); + for (int i = 0; i < reg_idx; ++i) { + vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); + } + inc(reg_h); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h, h_); + jl(l_next_h, T_NEAR); + } + } + // save right now + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); + for (int i = 0; i < rest_used_num_regs; ++i) { + vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); + } + } + save_rest(rest, w_offset); + } + + // return the number of used regs, use start from reg 0 + int load_rest(int rest, int w_offset, const int num_shift_regs, + const int reg_start = 0) { + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + int reg_idx = reg_start; + if (has_block4) { + vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + w_offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + w_offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + reg_idx++; + } + return reg_idx; + } + + // use reg start from 0 + void save_rest(int rest, int w_offset, int reg_start = 0) { + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + int reg_idx = reg_start; + if (has_block4) { + vmovups(ptr[param2 + w_offset], xmm_t(reg_idx)); + w_offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(ptr[param2 + w_offset], xmm_t(reg_idx)); + w_offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(ptr[param2 + w_offset], xmm_t(reg_idx)); + } + } + private: int h_; int w_; @@ -90,6 +196,10 @@ class SeqPoolJitCode : public JitCode { reg64_t param2{abi_param2}; reg64_t param3{abi_param3}; reg32_t reg32_scalar{r8d}; + + reg64_t reg_h{r9}; + reg64_t reg_ptr_src_i{r10}; + reg64_t reg_tmp{r11}; }; } // namespace gen -- GitLab