提交 123b98f4 编写于 作者: T tensor-tang

refine heigth and codesize and support all pool

test=develop
上级 0145f40f
...@@ -192,7 +192,8 @@ void BenchGRUKernel() { ...@@ -192,7 +192,8 @@ void BenchGRUKernel() {
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() { void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) { for (auto type : pool_types) {
for (int w : TestSizes()) { for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type); jit::seq_pool_attr_t attr(w, type);
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h" #include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <stddef.h> // offsetof #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -22,9 +22,6 @@ namespace operators { ...@@ -22,9 +22,6 @@ namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = {
1.f}; // TODO(TJ): try move to private
void SeqPoolJitCode::genCode() { void SeqPoolJitCode::genCode() {
constexpr int block = YMM_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
constexpr int max_num_regs = 8; constexpr int max_num_regs = 8;
...@@ -33,10 +30,17 @@ void SeqPoolJitCode::genCode() { ...@@ -33,10 +30,17 @@ void SeqPoolJitCode::genCode() {
int rest_num_regs = num_block % max_num_regs; int rest_num_regs = num_block % max_num_regs;
mov(reg32_int_h, dword[param_attr]); mov(reg32_int_h, dword[param_attr]);
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(float_h)); mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]);
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
fild(dword[param_attr]); fild(dword[param_attr]);
fstp(dword[reg_tmp]); fstp(dword[reg_tmp]);
mov(reg32_fp_h, dword[reg_tmp]); vmovss(xmm_t(0), ptr[reg_tmp]);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(xmm_t(0), xmm_t(0));
}
vdivps(xmm_t(1), xmm_t(1), xmm_t(0));
vmovss(ptr[reg_tmp], xmm_t(1));
} }
const int group_len = max_num_regs * block * sizeof(float); const int group_len = max_num_regs * block * sizeof(float);
for (int g = 0; g < num_groups; ++g) { for (int g = 0; g < num_groups; ++g) {
...@@ -45,7 +49,6 @@ void SeqPoolJitCode::genCode() { ...@@ -45,7 +49,6 @@ void SeqPoolJitCode::genCode() {
if (rest_num_regs > 0) { if (rest_num_regs > 0) {
pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs); pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
} }
// part of rest_w * height // part of rest_w * height
const int rest = w_ % block; const int rest = w_ % block;
pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs); pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
...@@ -58,12 +61,10 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> { ...@@ -58,12 +61,10 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
size_t CodeSize(const seq_pool_attr_t& attr) const override { size_t CodeSize(const seq_pool_attr_t& attr) const override {
// TODO(TJ): remove attr.h when enabled height return 96 +
bool yes = ((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) *
attr.type == SeqPoolType::kAvg || attr.type == SeqPoolType::kSqrt; 4 /* load, mul and save */ +
return 96 /* basic */ + 256) *
((attr.w / YMM_FLOAT_BLOCK + 4 /* rest */) * 2 /* for sum */
* (attr.h + (yes ? 3 : 1 /*for avg or sqrt*/))) *
8; 8;
} }
std::unique_ptr<GenBase> CreateJitCode( std::unique_ptr<GenBase> CreateJitCode(
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for ones
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode { ...@@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode {
size_t code_size = 256 * 1024, size_t code_size = 256 * 1024,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) { : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
if (type_ != SeqPoolType::kSum) { if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
type_ == SeqPoolType::kSqrt)) {
LOG(FATAL) << "Only support sum pool yet "; LOG(FATAL) << "Only support sum pool yet ";
} }
fp_h_[0] = 1.f;
this->genCode(); this->genCode();
} }
...@@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode { ...@@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode {
L(l_h_done); L(l_h_done);
// save right now // save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]); vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
movd(JMM(max_num_regs + 1), reg32_fp_h);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1));
}
vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1));
vbroadcastss(JMM(max_num_regs),
JMM(max_num_regs + 2)); // TODO(TJ): fix me
} }
offset = w_offset; offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) { for (int i = 0; i < max_num_regs; ++i) {
...@@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode { ...@@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode {
L(l_h_done); L(l_h_done);
// save right now // save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]); vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
movd(xmm_t(max_num_regs + 1), reg32_fp_h);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1));
}
vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs),
xmm_t(max_num_regs + 1));
vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2));
for (int i = 0; i < rest_used_num_regs; ++i) { for (int i = 0; i < rest_used_num_regs; ++i) {
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs)); vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
} }
...@@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode { ...@@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode {
} }
private: private:
float ALIGN32_BEG fp_h_[1] ALIGN32_END;
int w_; int w_;
SeqPoolType type_; SeqPoolType type_;
reg64_t param_src{abi_param1}; reg64_t param_src{abi_param1};
......
...@@ -436,8 +436,8 @@ void TestGRUKernel() { ...@@ -436,8 +436,8 @@ void TestGRUKernel() {
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() { void TestSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
// TODO(TJ): support more std::vector<jit::SeqPoolType> pool_types = {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) { for (auto type : pool_types) {
for (int w : TestSizes()) { for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type); jit::seq_pool_attr_t attr(w, type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册