提交 0145f40f 编写于 作者: T tensor-tang

use height from params of jitcode

上级 e0591dee
......@@ -195,8 +195,9 @@ void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) {
for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type);
attr.h = h;
std::vector<T> x(h * w), y(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -21,20 +22,22 @@ namespace operators {
namespace jit {
namespace gen {
thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = {
1.f}; // TODO(TJ): try move to private
void SeqPoolJitCode::genCode() {
constexpr int block = YMM_FLOAT_BLOCK;
constexpr int max_num_regs = 8;
const int num_block = w_ / block;
const int num_groups = num_block / max_num_regs;
int rest_num_regs = num_block % max_num_regs;
if (type_ == SeqPoolType::kAvg) {
float scalar = 1.f / h_;
mov(reg32_scalar, scalar);
} else if (type_ == SeqPoolType::kSqrt) {
float scalar = 1.f / std::sqrt(static_cast<float>(h_));
mov(reg32_scalar, scalar);
mov(reg32_int_h, dword[param_attr]);
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(float_h));
fild(dword[param_attr]);
fstp(dword[reg_tmp]);
mov(reg32_fp_h, dword[reg_tmp]);
}
const int group_len = max_num_regs * block * sizeof(float);
for (int g = 0; g < num_groups; ++g) {
pool_height<ymm_t>(g * group_len, block, max_num_regs);
......
......@@ -16,6 +16,7 @@
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for ones
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -29,7 +30,7 @@ class SeqPoolJitCode : public JitCode {
explicit SeqPoolJitCode(const seq_pool_attr_t& attr,
size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), h_(attr.h), w_(attr.w), type_(attr.type) {
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
if (type_ != SeqPoolType::kSum) {
LOG(FATAL) << "Only support sum pool yet ";
}
......@@ -55,39 +56,48 @@ class SeqPoolJitCode : public JitCode {
void pool_height(int w_offset, int block, int max_num_regs) {
int offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i), ptr[param1 + offset]);
vmovups(JMM(i), ptr[param_src + offset]);
offset += sizeof(float) * block;
}
if (h_ > 1) {
Label l_next_h;
mov(reg_h, 1);
mov(reg_tmp, param1);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
mov(reg_ptr_src_i, reg_tmp);
for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
// sum anyway
vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
add(reg_ptr_src_i, sizeof(float) * block);
}
inc(reg_h);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h, h_);
jl(l_next_h, T_NEAR);
cmp(reg32_int_h, 1);
Label l_next_h, l_h_done;
jle(l_h_done, T_NEAR);
mov(reg_h_i, 1);
mov(reg_tmp, param_src);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
mov(reg_ptr_src_i, reg_tmp);
for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
// sum anyway
vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
add(reg_ptr_src_i, sizeof(float) * block);
}
inc(reg_h_i);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h_i, reg32_int_h);
jl(l_next_h, T_NEAR);
}
L(l_h_done);
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vbroadcastss(JMM(max_num_regs), reg32_scalar);
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
movd(JMM(max_num_regs + 1), reg32_fp_h);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1));
}
vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1));
vbroadcastss(JMM(max_num_regs),
JMM(max_num_regs + 2)); // TODO(TJ): fix me
}
offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) {
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vmulps(JMM(i), JMM(i), JMM(max_num_regs));
}
vmovups(ptr[param2 + offset], JMM(i));
vmovups(ptr[param_dst + offset], JMM(i));
offset += sizeof(float) * block;
}
}
......@@ -97,47 +107,54 @@ class SeqPoolJitCode : public JitCode {
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
if (h_ > 1) {
Label l_next_h;
mov(reg_h, 1);
mov(reg_tmp, param1);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
// int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset,
// max_num_regs);
int reg_idx = 0;
mov(reg_ptr_src_i, reg_tmp);
if (has_block4) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 4);
reg_idx++;
}
if (has_block2) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 2);
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
reg_idx++;
}
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
"All heights should use same regs");
for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
}
inc(reg_h);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h, h_);
jl(l_next_h, T_NEAR);
cmp(reg32_int_h, 1);
Label l_next_h, l_h_done;
jle(l_h_done, T_NEAR);
mov(reg_h_i, 1);
mov(reg_tmp, param_src);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
int reg_idx = 0;
mov(reg_ptr_src_i, reg_tmp);
if (has_block4) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 4);
reg_idx++;
}
if (has_block2) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 2);
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
reg_idx++;
}
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
"All heights should use same regs");
for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
}
inc(reg_h_i);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h_i, reg32_int_h);
jl(l_next_h, T_NEAR);
}
L(l_h_done);
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar);
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
movd(xmm_t(max_num_regs + 1), reg32_fp_h);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1));
}
vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs),
xmm_t(max_num_regs + 1));
vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2));
for (int i = 0; i < rest_used_num_regs; ++i) {
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1));
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
}
}
save_rest(rest, w_offset);
......@@ -151,17 +168,17 @@ class SeqPoolJitCode : public JitCode {
const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start;
if (has_block4) {
vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
w_offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
w_offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
reg_idx++;
}
return reg_idx;
......@@ -174,32 +191,33 @@ class SeqPoolJitCode : public JitCode {
const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start;
if (has_block4) {
vmovups(ptr[param2 + w_offset], xmm_t(reg_idx));
vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(ptr[param2 + w_offset], xmm_t(reg_idx));
vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(ptr[param2 + w_offset], xmm_t(reg_idx));
vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx));
}
}
private:
int h_;
int w_;
SeqPoolType type_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
reg32_t reg32_scalar{r8d};
reg64_t param_src{abi_param1};
reg64_t param_dst{abi_param2};
reg64_t param_attr{abi_param3};
reg64_t reg_tmp{rax};
reg32_t reg32_int_h{r8d};
reg32_t reg32_fp_h{r9d};
reg64_t reg_h{r9};
reg64_t reg_ptr_src_i{r10};
reg64_t reg_tmp{r11};
reg64_t reg_h_i{r10};
reg64_t reg_ptr_src_i{r11};
};
} // namespace gen
......
......@@ -46,7 +46,7 @@ typedef enum {
typedef enum {
kNonePoolType = 0,
kSum,
kSum = 1,
kAvg,
kSqrt,
} SeqPoolType;
......@@ -121,10 +121,10 @@ struct GRUTuples {
};
typedef struct seq_pool_attr_s {
int h, w;
int h, w; // h should always be the first one
SeqPoolType type;
seq_pool_attr_s() = default;
explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type)
explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1)
: h(height), w(width), type(pool_type) {}
} seq_pool_attr_t;
......
......@@ -45,10 +45,8 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
size_t key = attr.w;
// TODO(TJ): support height, then removed it from key
constexpr int w_shift = 30;
return (key << act_type_shift) + static_cast<int>(attr.type) +
(static_cast<size_t>(attr.h) << (act_type_shift + w_shift));
constexpr int pool_type_shift = 3;
return (key << pool_type_shift) + static_cast<int>(attr.type);
}
} // namespace jit
......
......@@ -334,7 +334,6 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet");
for (int w = 0; w < attr->w; ++w) {
const T* src = x + w;
T* dst = y + w;
......
......@@ -439,9 +439,10 @@ void TestSeqPoolKernel() {
// TODO(TJ): support more
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) {
for (int h : TestSizes()) {
for (int w : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type);
for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
attr.h = h;
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(h * w), yref(w);
......
......@@ -252,14 +252,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE(platform::is_cpu_place(place));
const T* src = input.data<T>();
T* dst = output->mutable_data<T>(place);
jit::seq_pool_attr_t attr;
attr.w = input.numel() / input.dims()[0];
attr.type = jit::SeqPoolType::kSum;
jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum);
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
seqpool(src, dst, &attr);
dst += attr.w;
src += attr.h * attr.w;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册