From 0145f40f4576fa035b92e3876ca9c4cfefbc5c52 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sat, 5 Jan 2019 11:34:15 +0000 Subject: [PATCH] use height from params of jitcode --- paddle/fluid/operators/jit/benchmark.cc | 3 +- paddle/fluid/operators/jit/gen/seqpool.cc | 17 +- paddle/fluid/operators/jit/gen/seqpool.h | 162 ++++++++++-------- paddle/fluid/operators/jit/kernel_base.h | 6 +- paddle/fluid/operators/jit/kernel_key.cc | 6 +- paddle/fluid/operators/jit/refer/refer.h | 1 - paddle/fluid/operators/jit/test.cc | 7 +- .../fluid/operators/math/sequence_pooling.cc | 12 +- 8 files changed, 117 insertions(+), 97 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 37a552fb6d..4cbada4a5b 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -195,8 +195,9 @@ void BenchSeqPoolKernel() { std::vector pool_types = {jit::SeqPoolType::kSum}; for (auto type : pool_types) { for (int w : TestSizes()) { + jit::seq_pool_attr_t attr(w, type); for (int h : TestSizes()) { - const jit::seq_pool_attr_t attr(h, w, type); + attr.h = h; std::vector x(h * w), y(w); RandomVec(h * w, x.data(), -2.f, 2.f); const T* x_data = x.data(); diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc index fd83f83436..d651f282bf 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.cc +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -13,6 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/gen/seqpool.h" +#include // offsetof #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" @@ -21,20 +22,22 @@ namespace operators { namespace jit { namespace gen { +thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = { + 1.f}; // TODO(TJ): try move to private + void SeqPoolJitCode::genCode() { constexpr int block = YMM_FLOAT_BLOCK; constexpr int max_num_regs = 8; const int num_block = w_ / block; const int num_groups = num_block / max_num_regs; int rest_num_regs = num_block % max_num_regs; - if (type_ == SeqPoolType::kAvg) { - float scalar = 1.f / h_; - mov(reg32_scalar, scalar); - } else if (type_ == SeqPoolType::kSqrt) { - float scalar = 1.f / std::sqrt(static_cast(h_)); - mov(reg32_scalar, scalar); + mov(reg32_int_h, dword[param_attr]); + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + mov(reg_tmp, reinterpret_cast(float_h)); + fild(dword[param_attr]); + fstp(dword[reg_tmp]); + mov(reg32_fp_h, dword[reg_tmp]); } - const int group_len = max_num_regs * block * sizeof(float); for (int g = 0; g < num_groups; ++g) { pool_height(g * group_len, block, max_num_regs); diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h index 48288d8c2a..c61bf27cc1 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.h +++ b/paddle/fluid/operators/jit/gen/seqpool.h @@ -16,6 +16,7 @@ #include #include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/act.h" // for ones #include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/fluid/platform/enforce.h" @@ -29,7 +30,7 @@ class SeqPoolJitCode : public JitCode { explicit SeqPoolJitCode(const seq_pool_attr_t& attr, size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), h_(attr.h), w_(attr.w), type_(attr.type) { + : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) { if (type_ != SeqPoolType::kSum) { LOG(FATAL) << "Only support sum pool yet "; } @@ -55,39 +56,48 @@ class SeqPoolJitCode : public JitCode { void pool_height(int w_offset, int block, int max_num_regs) { int offset = w_offset; for (int i = 0; i < max_num_regs; ++i) { - vmovups(JMM(i), ptr[param1 + offset]); + vmovups(JMM(i), ptr[param_src + offset]); offset += sizeof(float) * block; } - if (h_ > 1) { - Label l_next_h; - mov(reg_h, 1); - mov(reg_tmp, param1); - add(reg_tmp, w_ * sizeof(float) + w_offset); - L(l_next_h); - { - mov(reg_ptr_src_i, reg_tmp); - for (int i = 0; i < max_num_regs; ++i) { - vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]); - // sum anyway - vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); - add(reg_ptr_src_i, sizeof(float) * block); - } - inc(reg_h); - add(reg_tmp, w_ * sizeof(float)); - cmp(reg_h, h_); - jl(l_next_h, T_NEAR); + cmp(reg32_int_h, 1); + Label l_next_h, l_h_done; + jle(l_h_done, T_NEAR); + mov(reg_h_i, 1); + mov(reg_tmp, param_src); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + mov(reg_ptr_src_i, reg_tmp); + for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]); + // sum anyway + vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); + add(reg_ptr_src_i, sizeof(float) * block); } + inc(reg_h_i); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h_i, reg32_int_h); + jl(l_next_h, T_NEAR); } + L(l_h_done); // save right now if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - vbroadcastss(JMM(max_num_regs), reg32_scalar); + mov(reg_tmp, reinterpret_cast(exp_float_consts)); + vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]); + movd(JMM(max_num_regs + 1), reg32_fp_h); + if (type_ == SeqPoolType::kSqrt) { + vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1)); + } + vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1)); + vbroadcastss(JMM(max_num_regs), + JMM(max_num_regs + 2)); // TODO(TJ): fix me } offset = w_offset; for (int i = 0; i < max_num_regs; ++i) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { vmulps(JMM(i), JMM(i), JMM(max_num_regs)); } - vmovups(ptr[param2 + offset], JMM(i)); + vmovups(ptr[param_dst + offset], JMM(i)); offset += sizeof(float) * block; } } @@ -97,47 +107,54 @@ class SeqPoolJitCode : public JitCode { const bool has_block4 = rest / 4 > 0; const bool has_block2 = (rest % 4) / 2 > 0; const bool has_block1 = (rest % 2) == 1; - if (h_ > 1) { - Label l_next_h; - mov(reg_h, 1); - mov(reg_tmp, param1); - add(reg_tmp, w_ * sizeof(float) + w_offset); - L(l_next_h); - { - // int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset, - // max_num_regs); - int reg_idx = 0; - mov(reg_ptr_src_i, reg_tmp); - if (has_block4) { - vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); - add(reg_ptr_src_i, sizeof(float) * 4); - reg_idx++; - } - if (has_block2) { - vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); - add(reg_ptr_src_i, sizeof(float) * 2); - reg_idx++; - } - if (has_block1) { - vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); - reg_idx++; - } - PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs, - "All heights should use same regs"); - for (int i = 0; i < reg_idx; ++i) { - vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); - } - inc(reg_h); - add(reg_tmp, w_ * sizeof(float)); - cmp(reg_h, h_); - jl(l_next_h, T_NEAR); + cmp(reg32_int_h, 1); + Label l_next_h, l_h_done; + jle(l_h_done, T_NEAR); + mov(reg_h_i, 1); + mov(reg_tmp, param_src); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + int reg_idx = 0; + mov(reg_ptr_src_i, reg_tmp); + if (has_block4) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 4); + reg_idx++; + } + if (has_block2) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 2); + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + reg_idx++; } + PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs, + "All heights should use same regs"); + for (int i = 0; i < reg_idx; ++i) { + vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); + } + inc(reg_h_i); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h_i, reg32_int_h); + jl(l_next_h, T_NEAR); } + L(l_h_done); // save right now if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); + mov(reg_tmp, reinterpret_cast(exp_float_consts)); + vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]); + movd(xmm_t(max_num_regs + 1), reg32_fp_h); + if (type_ == SeqPoolType::kSqrt) { + vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1)); + } + vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs), + xmm_t(max_num_regs + 1)); + vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2)); for (int i = 0; i < rest_used_num_regs; ++i) { - vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); + vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs)); } } save_rest(rest, w_offset); @@ -151,17 +168,17 @@ class SeqPoolJitCode : public JitCode { const bool has_block1 = (rest % 2) == 1; int reg_idx = reg_start; if (has_block4) { - vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); w_offset += sizeof(float) * 4; reg_idx++; } if (has_block2) { - vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); w_offset += sizeof(float) * 2; reg_idx++; } if (has_block1) { - vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); reg_idx++; } return reg_idx; @@ -174,32 +191,33 @@ class SeqPoolJitCode : public JitCode { const bool has_block1 = (rest % 2) == 1; int reg_idx = reg_start; if (has_block4) { - vmovups(ptr[param2 + w_offset], xmm_t(reg_idx)); + vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx)); w_offset += sizeof(float) * 4; reg_idx++; } if (has_block2) { - vmovq(ptr[param2 + w_offset], xmm_t(reg_idx)); + vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx)); w_offset += sizeof(float) * 2; reg_idx++; } if (has_block1) { - vmovss(ptr[param2 + w_offset], xmm_t(reg_idx)); + vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx)); } } private: - int h_; int w_; SeqPoolType type_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - reg64_t param3{abi_param3}; - reg32_t reg32_scalar{r8d}; + reg64_t param_src{abi_param1}; + reg64_t param_dst{abi_param2}; + reg64_t param_attr{abi_param3}; + reg64_t reg_tmp{rax}; + + reg32_t reg32_int_h{r8d}; + reg32_t reg32_fp_h{r9d}; - reg64_t reg_h{r9}; - reg64_t reg_ptr_src_i{r10}; - reg64_t reg_tmp{r11}; + reg64_t reg_h_i{r10}; + reg64_t reg_ptr_src_i{r11}; }; } // namespace gen diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 2659374650..2a7697a6f2 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -46,7 +46,7 @@ typedef enum { typedef enum { kNonePoolType = 0, - kSum, + kSum = 1, kAvg, kSqrt, } SeqPoolType; @@ -121,10 +121,10 @@ struct GRUTuples { }; typedef struct seq_pool_attr_s { - int h, w; + int h, w; // h should always be the first one SeqPoolType type; seq_pool_attr_s() = default; - explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type) + explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1) : h(height), w(width), type(pool_type) {} } seq_pool_attr_t; diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index db78ed8ad8..61de386886 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -45,10 +45,8 @@ size_t JitCodeKey(const gru_attr_t& attr) { template <> size_t JitCodeKey(const seq_pool_attr_t& attr) { size_t key = attr.w; - // TODO(TJ): support height, then removed it from key - constexpr int w_shift = 30; - return (key << act_type_shift) + static_cast(attr.type) + - (static_cast(attr.h) << (act_type_shift + w_shift)); + constexpr int pool_type_shift = 3; + return (key << pool_type_shift) + static_cast(attr.type); } } // namespace jit diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 4e19783c86..b4e9c8dd10 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -334,7 +334,6 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { template void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { - PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet"); for (int w = 0; w < attr->w; ++w) { const T* src = x + w; T* dst = y + w; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 0f1776507a..5e05c71f40 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -439,9 +439,10 @@ void TestSeqPoolKernel() { // TODO(TJ): support more std::vector pool_types = {jit::SeqPoolType::kSum}; for (auto type : pool_types) { - for (int h : TestSizes()) { - for (int w : TestSizes()) { - const jit::seq_pool_attr_t attr(h, w, type); + for (int w : TestSizes()) { + jit::seq_pool_attr_t attr(w, type); + for (int h : TestSizes()) { + attr.h = h; auto ref = jit::GetRefer>(); EXPECT_TRUE(ref != nullptr); std::vector x(h * w), yref(w); diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 283e2e251a..2a47502614 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -252,14 +252,14 @@ class SequencePoolFunctor { PADDLE_ENFORCE(platform::is_cpu_place(place)); const T* src = input.data(); T* dst = output->mutable_data(place); - jit::seq_pool_attr_t attr; - attr.w = input.numel() / input.dims()[0]; - attr.type = jit::SeqPoolType::kSum; + jit::seq_pool_attr_t attr( + static_cast(input.numel() / input.dims()[0]), + jit::SeqPoolType::kSum); + auto seqpool = + jit::Get, platform::CPUPlace>( + attr); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { attr.h = static_cast(lod[i + 1] - lod[i]); - auto seqpool = - jit::Get, platform::CPUPlace>( - attr); seqpool(src, dst, &attr); dst += attr.w; src += attr.h * attr.w; -- GitLab