diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 4cbada4a5bf3334e8da5167cecb8e3cf0c19f744..bde2791add4075be6949703dfbea634966d25c1c 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -192,7 +192,8 @@ void BenchGRUKernel() {
 
 template
 void BenchSeqPoolKernel() {
-  std::vector pool_types = {jit::SeqPoolType::kSum};
+  std::vector pool_types = {
+      jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
   for (auto type : pool_types) {
     for (int w : TestSizes()) {
       jit::seq_pool_attr_t attr(w, type);
diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc
index d651f282bf16545ae5de1ff77b41769412c062cb..530d24ee1fb7d9da84102641e1d4d2ab08ab1860 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.cc
+++ b/paddle/fluid/operators/jit/gen/seqpool.cc
@@ -13,7 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/gen/seqpool.h"
-#include // offsetof
+#include "paddle/fluid/operators/jit/gen/act.h"  // for exp_float_consts ones
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -22,9 +22,6 @@ namespace operators {
 namespace jit {
 namespace gen {
 
-thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = {
-    1.f};  // TODO(TJ): try move to private
-
 void SeqPoolJitCode::genCode() {
   constexpr int block = YMM_FLOAT_BLOCK;
   constexpr int max_num_regs = 8;
@@ -33,10 +30,17 @@ void SeqPoolJitCode::genCode() {
   int rest_num_regs = num_block % max_num_regs;
   mov(reg32_int_h, dword[param_attr]);
   if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-    mov(reg_tmp, reinterpret_cast(float_h));
+    mov(reg_tmp, reinterpret_cast(exp_float_consts));
+    vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]);
+    mov(reg_tmp, reinterpret_cast(fp_h_));
     fild(dword[param_attr]);
     fstp(dword[reg_tmp]);
-    mov(reg32_fp_h, dword[reg_tmp]);
+    vmovss(xmm_t(0), ptr[reg_tmp]);
+    if (type_ == SeqPoolType::kSqrt) {
+      vsqrtps(xmm_t(0), xmm_t(0));
+    }
+    vdivps(xmm_t(1), xmm_t(1), xmm_t(0));
+    vmovss(ptr[reg_tmp], xmm_t(1));
   }
   const int group_len = max_num_regs * block * sizeof(float);
   for (int g = 0; g < num_groups; ++g) {
@@ -45,7 +49,6 @@ void SeqPoolJitCode::genCode() {
   if (rest_num_regs > 0) {
     pool_height(num_groups * group_len, block, rest_num_regs);
   }
-
   // part of rest_w * height
   const int rest = w_ % block;
   pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
@@ -58,12 +61,10 @@ class SeqPoolCreator : public JitCodeCreator {
     return platform::MayIUse(platform::avx);
   }
   size_t CodeSize(const seq_pool_attr_t& attr) const override {
-    // TODO(TJ): remove attr.h when enabled height
-    bool yes =
-        attr.type == SeqPoolType::kAvg || attr.type == SeqPoolType::kSqrt;
-    return 96 /* basic */ +
-           ((attr.w / YMM_FLOAT_BLOCK + 4 /* rest */) * 2 /* for sum */
-            * (attr.h + (yes ? 3 : 1 /*for avg or sqrt*/))) *
+    return 96 +
+           ((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) *
+                4 /* load, mul and save */ +
+            256) *
                8;
   }
   std::unique_ptr CreateJitCode(
diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h
index c61bf27cc1bc33b840cb3187bf224b8f4f516b76..fcbbb3c84c562e2ba57110134bf07bb218b41edb 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.h
+++ b/paddle/fluid/operators/jit/gen/seqpool.h
@@ -16,7 +16,6 @@
 
 #include
 #include "glog/logging.h"
-#include "paddle/fluid/operators/jit/gen/act.h"  // for ones
 #include "paddle/fluid/operators/jit/gen/jitcode.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode {
                           size_t code_size = 256 * 1024,
                           void* code_ptr = nullptr)
       : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
-    if (type_ != SeqPoolType::kSum) {
-      LOG(FATAL) << "Only support sum pool yet ";
+    if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
+          type_ == SeqPoolType::kSqrt)) {
+      LOG(FATAL) << "Only support sum, avg and sqrt pool yet ";
     }
+    fp_h_[0] = 1.f;
     this->genCode();
   }
 
@@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode {
     L(l_h_done);
     // save right now
     if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-      mov(reg_tmp, reinterpret_cast(exp_float_consts));
-      vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
-      movd(JMM(max_num_regs + 1), reg32_fp_h);
-      if (type_ == SeqPoolType::kSqrt) {
-        vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1));
-      }
-      vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1));
-      vbroadcastss(JMM(max_num_regs),
-                   JMM(max_num_regs + 2));  // TODO(TJ): fix me
+      mov(reg_tmp, reinterpret_cast(fp_h_));
+      vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
     }
     offset = w_offset;
     for (int i = 0; i < max_num_regs; ++i) {
@@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode {
     L(l_h_done);
     // save right now
     if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-      mov(reg_tmp, reinterpret_cast(exp_float_consts));
-      vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
-      movd(xmm_t(max_num_regs + 1), reg32_fp_h);
-      if (type_ == SeqPoolType::kSqrt) {
-        vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1));
-      }
-      vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs),
-             xmm_t(max_num_regs + 1));
-      vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2));
+      mov(reg_tmp, reinterpret_cast(fp_h_));
+      vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
       for (int i = 0; i < rest_used_num_regs; ++i) {
         vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
       }
@@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode {
   }
 
  private:
+  float ALIGN32_BEG fp_h_[1] ALIGN32_END;
   int w_;
   SeqPoolType type_;
   reg64_t param_src{abi_param1};
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index 5e05c71f40479b8978bebb2473bdffee5fff27ae..30291bfef3bc96fe2e687e5be6d782eee89496aa 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -436,8 +436,8 @@ void TestGRUKernel() {
 template
 void TestSeqPoolKernel() {
   VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
-  // TODO(TJ): support more
-  std::vector pool_types = {jit::SeqPoolType::kSum};
+  std::vector pool_types = {
+      jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
   for (auto type : pool_types) {
     for (int w : TestSizes()) {
       jit::seq_pool_attr_t attr(w, type);