提交 0145f40f 编写于 作者: T tensor-tang

use height from params of jitcode

上级 e0591dee
...@@ -195,8 +195,9 @@ void BenchSeqPoolKernel() { ...@@ -195,8 +195,9 @@ void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) { for (auto type : pool_types) {
for (int w : TestSizes()) { for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) { for (int h : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type); attr.h = h;
std::vector<T> x(h * w), y(w); std::vector<T> x(h * w), y(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f); RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data(); const T* x_data = x.data();
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h" #include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -21,20 +22,22 @@ namespace operators { ...@@ -21,20 +22,22 @@ namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = {
1.f}; // TODO(TJ): try move to private
void SeqPoolJitCode::genCode() { void SeqPoolJitCode::genCode() {
constexpr int block = YMM_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
constexpr int max_num_regs = 8; constexpr int max_num_regs = 8;
const int num_block = w_ / block; const int num_block = w_ / block;
const int num_groups = num_block / max_num_regs; const int num_groups = num_block / max_num_regs;
int rest_num_regs = num_block % max_num_regs; int rest_num_regs = num_block % max_num_regs;
if (type_ == SeqPoolType::kAvg) { mov(reg32_int_h, dword[param_attr]);
float scalar = 1.f / h_; if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg32_scalar, scalar); mov(reg_tmp, reinterpret_cast<size_t>(float_h));
} else if (type_ == SeqPoolType::kSqrt) { fild(dword[param_attr]);
float scalar = 1.f / std::sqrt(static_cast<float>(h_)); fstp(dword[reg_tmp]);
mov(reg32_scalar, scalar); mov(reg32_fp_h, dword[reg_tmp]);
} }
const int group_len = max_num_regs * block * sizeof(float); const int group_len = max_num_regs * block * sizeof(float);
for (int g = 0; g < num_groups; ++g) { for (int g = 0; g < num_groups; ++g) {
pool_height<ymm_t>(g * group_len, block, max_num_regs); pool_height<ymm_t>(g * group_len, block, max_num_regs);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for ones
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -29,7 +30,7 @@ class SeqPoolJitCode : public JitCode { ...@@ -29,7 +30,7 @@ class SeqPoolJitCode : public JitCode {
explicit SeqPoolJitCode(const seq_pool_attr_t& attr, explicit SeqPoolJitCode(const seq_pool_attr_t& attr,
size_t code_size = 256 * 1024, size_t code_size = 256 * 1024,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), h_(attr.h), w_(attr.w), type_(attr.type) { : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
if (type_ != SeqPoolType::kSum) { if (type_ != SeqPoolType::kSum) {
LOG(FATAL) << "Only support sum pool yet "; LOG(FATAL) << "Only support sum pool yet ";
} }
...@@ -55,13 +56,14 @@ class SeqPoolJitCode : public JitCode { ...@@ -55,13 +56,14 @@ class SeqPoolJitCode : public JitCode {
void pool_height(int w_offset, int block, int max_num_regs) { void pool_height(int w_offset, int block, int max_num_regs) {
int offset = w_offset; int offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) { for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i), ptr[param1 + offset]); vmovups(JMM(i), ptr[param_src + offset]);
offset += sizeof(float) * block; offset += sizeof(float) * block;
} }
if (h_ > 1) { cmp(reg32_int_h, 1);
Label l_next_h; Label l_next_h, l_h_done;
mov(reg_h, 1); jle(l_h_done, T_NEAR);
mov(reg_tmp, param1); mov(reg_h_i, 1);
mov(reg_tmp, param_src);
add(reg_tmp, w_ * sizeof(float) + w_offset); add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h); L(l_next_h);
{ {
...@@ -72,22 +74,30 @@ class SeqPoolJitCode : public JitCode { ...@@ -72,22 +74,30 @@ class SeqPoolJitCode : public JitCode {
vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
add(reg_ptr_src_i, sizeof(float) * block); add(reg_ptr_src_i, sizeof(float) * block);
} }
inc(reg_h); inc(reg_h_i);
add(reg_tmp, w_ * sizeof(float)); add(reg_tmp, w_ * sizeof(float));
cmp(reg_h, h_); cmp(reg_h_i, reg32_int_h);
jl(l_next_h, T_NEAR); jl(l_next_h, T_NEAR);
} }
} L(l_h_done);
// save right now // save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vbroadcastss(JMM(max_num_regs), reg32_scalar); mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
movd(JMM(max_num_regs + 1), reg32_fp_h);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1));
}
vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1));
vbroadcastss(JMM(max_num_regs),
JMM(max_num_regs + 2)); // TODO(TJ): fix me
} }
offset = w_offset; offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) { for (int i = 0; i < max_num_regs; ++i) {
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vmulps(JMM(i), JMM(i), JMM(max_num_regs)); vmulps(JMM(i), JMM(i), JMM(max_num_regs));
} }
vmovups(ptr[param2 + offset], JMM(i)); vmovups(ptr[param_dst + offset], JMM(i));
offset += sizeof(float) * block; offset += sizeof(float) * block;
} }
} }
...@@ -97,15 +107,14 @@ class SeqPoolJitCode : public JitCode { ...@@ -97,15 +107,14 @@ class SeqPoolJitCode : public JitCode {
const bool has_block4 = rest / 4 > 0; const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0; const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1; const bool has_block1 = (rest % 2) == 1;
if (h_ > 1) { cmp(reg32_int_h, 1);
Label l_next_h; Label l_next_h, l_h_done;
mov(reg_h, 1); jle(l_h_done, T_NEAR);
mov(reg_tmp, param1); mov(reg_h_i, 1);
mov(reg_tmp, param_src);
add(reg_tmp, w_ * sizeof(float) + w_offset); add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h); L(l_next_h);
{ {
// int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset,
// max_num_regs);
int reg_idx = 0; int reg_idx = 0;
mov(reg_ptr_src_i, reg_tmp); mov(reg_ptr_src_i, reg_tmp);
if (has_block4) { if (has_block4) {
...@@ -127,17 +136,25 @@ class SeqPoolJitCode : public JitCode { ...@@ -127,17 +136,25 @@ class SeqPoolJitCode : public JitCode {
for (int i = 0; i < reg_idx; ++i) { for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
} }
inc(reg_h); inc(reg_h_i);
add(reg_tmp, w_ * sizeof(float)); add(reg_tmp, w_ * sizeof(float));
cmp(reg_h, h_); cmp(reg_h_i, reg32_int_h);
jl(l_next_h, T_NEAR); jl(l_next_h, T_NEAR);
} }
} L(l_h_done);
// save right now // save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
movd(xmm_t(max_num_regs + 1), reg32_fp_h);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1));
}
vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs),
xmm_t(max_num_regs + 1));
vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2));
for (int i = 0; i < rest_used_num_regs; ++i) { for (int i = 0; i < rest_used_num_regs; ++i) {
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
} }
} }
save_rest(rest, w_offset); save_rest(rest, w_offset);
...@@ -151,17 +168,17 @@ class SeqPoolJitCode : public JitCode { ...@@ -151,17 +168,17 @@ class SeqPoolJitCode : public JitCode {
const bool has_block1 = (rest % 2) == 1; const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start; int reg_idx = reg_start;
if (has_block4) { if (has_block4) {
vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
w_offset += sizeof(float) * 4; w_offset += sizeof(float) * 4;
reg_idx++; reg_idx++;
} }
if (has_block2) { if (has_block2) {
vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
w_offset += sizeof(float) * 2; w_offset += sizeof(float) * 2;
reg_idx++; reg_idx++;
} }
if (has_block1) { if (has_block1) {
vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
reg_idx++; reg_idx++;
} }
return reg_idx; return reg_idx;
...@@ -174,32 +191,33 @@ class SeqPoolJitCode : public JitCode { ...@@ -174,32 +191,33 @@ class SeqPoolJitCode : public JitCode {
const bool has_block1 = (rest % 2) == 1; const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start; int reg_idx = reg_start;
if (has_block4) { if (has_block4) {
vmovups(ptr[param2 + w_offset], xmm_t(reg_idx)); vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 4; w_offset += sizeof(float) * 4;
reg_idx++; reg_idx++;
} }
if (has_block2) { if (has_block2) {
vmovq(ptr[param2 + w_offset], xmm_t(reg_idx)); vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 2; w_offset += sizeof(float) * 2;
reg_idx++; reg_idx++;
} }
if (has_block1) { if (has_block1) {
vmovss(ptr[param2 + w_offset], xmm_t(reg_idx)); vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx));
} }
} }
private: private:
int h_;
int w_; int w_;
SeqPoolType type_; SeqPoolType type_;
reg64_t param1{abi_param1}; reg64_t param_src{abi_param1};
reg64_t param2{abi_param2}; reg64_t param_dst{abi_param2};
reg64_t param3{abi_param3}; reg64_t param_attr{abi_param3};
reg32_t reg32_scalar{r8d}; reg64_t reg_tmp{rax};
reg32_t reg32_int_h{r8d};
reg32_t reg32_fp_h{r9d};
reg64_t reg_h{r9}; reg64_t reg_h_i{r10};
reg64_t reg_ptr_src_i{r10}; reg64_t reg_ptr_src_i{r11};
reg64_t reg_tmp{r11};
}; };
} // namespace gen } // namespace gen
......
...@@ -46,7 +46,7 @@ typedef enum { ...@@ -46,7 +46,7 @@ typedef enum {
typedef enum { typedef enum {
kNonePoolType = 0, kNonePoolType = 0,
kSum, kSum = 1,
kAvg, kAvg,
kSqrt, kSqrt,
} SeqPoolType; } SeqPoolType;
...@@ -121,10 +121,10 @@ struct GRUTuples { ...@@ -121,10 +121,10 @@ struct GRUTuples {
}; };
typedef struct seq_pool_attr_s { typedef struct seq_pool_attr_s {
int h, w; int h, w; // h should always be the first one
SeqPoolType type; SeqPoolType type;
seq_pool_attr_s() = default; seq_pool_attr_s() = default;
explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type) explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1)
: h(height), w(width), type(pool_type) {} : h(height), w(width), type(pool_type) {}
} seq_pool_attr_t; } seq_pool_attr_t;
......
...@@ -45,10 +45,8 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { ...@@ -45,10 +45,8 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
template <> template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) { size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
size_t key = attr.w; size_t key = attr.w;
// TODO(TJ): support height, then removed it from key constexpr int pool_type_shift = 3;
constexpr int w_shift = 30; return (key << pool_type_shift) + static_cast<int>(attr.type);
return (key << act_type_shift) + static_cast<int>(attr.type) +
(static_cast<size_t>(attr.h) << (act_type_shift + w_shift));
} }
} // namespace jit } // namespace jit
......
...@@ -334,7 +334,6 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { ...@@ -334,7 +334,6 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
template <typename T> template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet");
for (int w = 0; w < attr->w; ++w) { for (int w = 0; w < attr->w; ++w) {
const T* src = x + w; const T* src = x + w;
T* dst = y + w; T* dst = y + w;
......
...@@ -439,9 +439,10 @@ void TestSeqPoolKernel() { ...@@ -439,9 +439,10 @@ void TestSeqPoolKernel() {
// TODO(TJ): support more // TODO(TJ): support more
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) { for (auto type : pool_types) {
for (int h : TestSizes()) {
for (int w : TestSizes()) { for (int w : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type); jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
attr.h = h;
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>(); auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> x(h * w), yref(w); std::vector<T> x(h * w), yref(w);
......
...@@ -252,14 +252,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -252,14 +252,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE(platform::is_cpu_place(place)); PADDLE_ENFORCE(platform::is_cpu_place(place));
const T* src = input.data<T>(); const T* src = input.data<T>();
T* dst = output->mutable_data<T>(place); T* dst = output->mutable_data<T>(place);
jit::seq_pool_attr_t attr; jit::seq_pool_attr_t attr(
attr.w = input.numel() / input.dims()[0]; static_cast<int>(input.numel() / input.dims()[0]),
attr.type = jit::SeqPoolType::kSum; jit::SeqPoolType::kSum);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
auto seqpool = auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr); attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr); seqpool(src, dst, &attr);
dst += attr.w; dst += attr.w;
src += attr.h * attr.w; src += attr.h * attr.w;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册