提交 123b98f4 编写于 作者: T tensor-tang

refine height and codesize and support all pool types

test=develop
上级 0145f40f
......@@ -192,7 +192,8 @@ void BenchGRUKernel() {
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) {
for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
......
......@@ -13,7 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -22,9 +22,6 @@ namespace operators {
namespace jit {
namespace gen {
thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = {
1.f}; // TODO(TJ): try move to private
void SeqPoolJitCode::genCode() {
constexpr int block = YMM_FLOAT_BLOCK;
constexpr int max_num_regs = 8;
......@@ -33,10 +30,17 @@ void SeqPoolJitCode::genCode() {
int rest_num_regs = num_block % max_num_regs;
mov(reg32_int_h, dword[param_attr]);
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(float_h));
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]);
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
fild(dword[param_attr]);
fstp(dword[reg_tmp]);
mov(reg32_fp_h, dword[reg_tmp]);
vmovss(xmm_t(0), ptr[reg_tmp]);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(xmm_t(0), xmm_t(0));
}
vdivps(xmm_t(1), xmm_t(1), xmm_t(0));
vmovss(ptr[reg_tmp], xmm_t(1));
}
const int group_len = max_num_regs * block * sizeof(float);
for (int g = 0; g < num_groups; ++g) {
......@@ -45,7 +49,6 @@ void SeqPoolJitCode::genCode() {
if (rest_num_regs > 0) {
pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
}
// part of rest_w * height
const int rest = w_ % block;
pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
......@@ -58,12 +61,10 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
return platform::MayIUse(platform::avx);
}
size_t CodeSize(const seq_pool_attr_t& attr) const override {
// TODO(TJ): remove attr.h when enabled height
bool yes =
attr.type == SeqPoolType::kAvg || attr.type == SeqPoolType::kSqrt;
return 96 /* basic */ +
((attr.w / YMM_FLOAT_BLOCK + 4 /* rest */) * 2 /* for sum */
* (attr.h + (yes ? 3 : 1 /*for avg or sqrt*/))) *
return 96 +
((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) *
4 /* load, mul and save */ +
256) *
8;
}
std::unique_ptr<GenBase> CreateJitCode(
......
......@@ -16,7 +16,6 @@
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for ones
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode {
size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
if (type_ != SeqPoolType::kSum) {
if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
type_ == SeqPoolType::kSqrt)) {
LOG(FATAL) << "Only support sum pool yet ";
}
fp_h_[0] = 1.f;
this->genCode();
}
......@@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode {
L(l_h_done);
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
movd(JMM(max_num_regs + 1), reg32_fp_h);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1));
}
vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1));
vbroadcastss(JMM(max_num_regs),
JMM(max_num_regs + 2)); // TODO(TJ): fix me
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
}
offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) {
......@@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode {
L(l_h_done);
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
movd(xmm_t(max_num_regs + 1), reg32_fp_h);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1));
}
vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs),
xmm_t(max_num_regs + 1));
vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2));
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
for (int i = 0; i < rest_used_num_regs; ++i) {
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
}
......@@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode {
}
private:
float ALIGN32_BEG fp_h_[1] ALIGN32_END;
int w_;
SeqPoolType type_;
reg64_t param_src{abi_param1};
......
......@@ -436,8 +436,8 @@ void TestGRUKernel() {
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
// TODO(TJ): support more
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) {
for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册