未验证 提交 223c61ca 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #15170 from tensor-tang/jit/seqpool

refine seqpool op
......@@ -190,6 +190,26 @@ void BenchGRUKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) {
for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
attr.h = h;
std::vector<T> x(h * w), y(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* y_data = y.data();
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
y_data, &attr);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
......@@ -228,4 +248,7 @@ int main(int argc, char* argv[]) {
BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
// seq pool function
BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
}
......@@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN(kGRUHtPart2)
USE_JITKERNEL_GEN(kNCHW16CMulNC)
USE_JITKERNEL_GEN(kSeqPool)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void SeqPoolJitCode::genCode() {
constexpr int block = YMM_FLOAT_BLOCK;
constexpr int max_num_regs = 8;
const int num_block = w_ / block;
const int num_groups = num_block / max_num_regs;
int rest_num_regs = num_block % max_num_regs;
mov(reg32_int_h, dword[param_attr]);
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]);
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
fild(dword[param_attr]);
fstp(dword[reg_tmp]);
vmovss(xmm_t(0), ptr[reg_tmp]);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(xmm_t(0), xmm_t(0));
}
vdivps(xmm_t(1), xmm_t(1), xmm_t(0));
vmovss(ptr[reg_tmp], xmm_t(1));
}
const int group_len = max_num_regs * block * sizeof(float);
for (int g = 0; g < num_groups; ++g) {
pool_height<ymm_t>(g * group_len, block, max_num_regs);
}
if (rest_num_regs > 0) {
pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
}
// part of rest_w * height
const int rest = w_ % block;
pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
ret();
}
class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
public:
bool UseMe(const seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx);
}
size_t CodeSize(const seq_pool_attr_t& attr) const override {
return 96 +
((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) *
4 /* load, mul and save */ +
256) *
8;
}
std::unique_ptr<GenBase> CreateJitCode(
const seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.w, 0);
PADDLE_ENFORCE_GT(attr.h, 0);
return make_unique<SeqPoolJitCode>(attr, CodeSize(attr));
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class SeqPoolJitCode : public JitCode {
public:
explicit SeqPoolJitCode(const seq_pool_attr_t& attr,
size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
type_ == SeqPoolType::kSqrt)) {
LOG(FATAL) << "Only support sum pool yet ";
}
fp_h_[0] = 1.f;
this->genCode();
}
virtual const char* name() const {
std::string base = "SeqPoolJitCode";
if (type_ == SeqPoolType::kSum) {
base += "_Sum";
} else if (type_ == SeqPoolType::kAvg) {
base += "_Avg";
} else if (type_ == SeqPoolType::kSqrt) {
base += "_Sqrt";
}
base += ("_W" + std::to_string(w_));
return base.c_str();
}
void genCode() override;
protected:
template <typename JMM>
void pool_height(int w_offset, int block, int max_num_regs) {
int offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i), ptr[param_src + offset]);
offset += sizeof(float) * block;
}
cmp(reg32_int_h, 1);
Label l_next_h, l_h_done;
jle(l_h_done, T_NEAR);
mov(reg_h_i, 1);
mov(reg_tmp, param_src);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
mov(reg_ptr_src_i, reg_tmp);
for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
// sum anyway
vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
add(reg_ptr_src_i, sizeof(float) * block);
}
inc(reg_h_i);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h_i, reg32_int_h);
jl(l_next_h, T_NEAR);
}
L(l_h_done);
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
}
offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) {
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vmulps(JMM(i), JMM(i), JMM(max_num_regs));
}
vmovups(ptr[param_dst + offset], JMM(i));
offset += sizeof(float) * block;
}
}
void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) {
const int rest_used_num_regs = load_rest(rest, w_offset, 0);
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
cmp(reg32_int_h, 1);
Label l_next_h, l_h_done;
jle(l_h_done, T_NEAR);
mov(reg_h_i, 1);
mov(reg_tmp, param_src);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
int reg_idx = 0;
mov(reg_ptr_src_i, reg_tmp);
if (has_block4) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 4);
reg_idx++;
}
if (has_block2) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 2);
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
reg_idx++;
}
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
"All heights should use same regs");
for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
}
inc(reg_h_i);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h_i, reg32_int_h);
jl(l_next_h, T_NEAR);
}
L(l_h_done);
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
for (int i = 0; i < rest_used_num_regs; ++i) {
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
}
}
save_rest(rest, w_offset);
}
// return the number of used regs, use start from reg 0
int load_rest(int rest, int w_offset, const int num_shift_regs,
const int reg_start = 0) {
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start;
if (has_block4) {
vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
w_offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
w_offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
reg_idx++;
}
return reg_idx;
}
// use reg start from 0
void save_rest(int rest, int w_offset, int reg_start = 0) {
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start;
if (has_block4) {
vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx));
}
}
private:
float ALIGN32_BEG fp_h_[1] ALIGN32_END;
int w_;
SeqPoolType type_;
reg64_t param_src{abi_param1};
reg64_t param_dst{abi_param2};
reg64_t param_attr{abi_param3};
reg64_t reg_tmp{rax};
reg32_t reg32_int_h{r8d};
reg32_t reg32_fp_h{r9d};
reg64_t reg_h_i{r10};
reg64_t reg_ptr_src_i{r11};
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -26,6 +26,7 @@ namespace jit {
const char* to_string(KernelType kt) {
switch (kt) {
ONE_CASE(kNone);
ONE_CASE(kVMul);
ONE_CASE(kVAdd);
ONE_CASE(kVAddRelu);
......@@ -45,12 +46,26 @@ const char* to_string(KernelType kt) {
ONE_CASE(kCRFDecoding);
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
}
return nullptr;
}
const char* to_string(SeqPoolType tp) {
switch (tp) {
ONE_CASE(kNonePoolType);
ONE_CASE(kSum);
ONE_CASE(kAvg);
ONE_CASE(kSqrt);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", tp);
return "NOT PoolType";
}
return nullptr;
}
#undef ONE_CASE
KernelType to_kerneltype(const std::string& act) {
......
......@@ -119,6 +119,7 @@ typename KernelTuples::func_type Get(
}
const char* to_string(KernelType kt);
const char* to_string(SeqPoolType kt);
KernelType to_kerneltype(const std::string& act);
......@@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
<< "],act_cand[" << to_string(attr.act_cand) << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) {
os << "height_size[" << attr.h << "],width_size[" << attr.w << "],pool_type["
<< to_string(attr.type) << "]";
return os;
}
} // namespace jit
} // namespace operators
......
......@@ -41,8 +41,16 @@ typedef enum {
kCRFDecoding,
kLayerNorm,
kNCHW16CMulNC,
kSeqPool,
} KernelType;
typedef enum {
kNonePoolType = 0,
kSum = 1,
kAvg,
kSqrt,
} SeqPoolType;
template <typename T>
struct XYZNTuples {
typedef T data_type;
......@@ -112,6 +120,21 @@ struct GRUTuples {
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
typedef struct seq_pool_attr_s {
int h, w; // h should always be the first one
SeqPoolType type;
seq_pool_attr_s() = default;
explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1)
: h(height), w(width), type(pool_type) {}
} seq_pool_attr_t;
template <typename T>
struct SeqPoolTuples {
typedef T data_type;
typedef seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
};
template <typename T>
struct CRFDecodingTuples {
typedef T data_type;
......
......@@ -42,6 +42,13 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
(static_cast<int>(attr.act_cand) << act_type_shift);
}
template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
size_t key = attr.w;
constexpr int pool_type_shift = 3;
return (key << pool_type_shift) + static_cast<int>(attr.type);
}
} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
......@@ -72,6 +72,26 @@ void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
template <>
void VCopy<float>(const float* x, float* y, int n) {
platform::dynload::cblas_scopy(n, x, 1, y, 1);
}
template <>
void VCopy<double>(const double* x, double* y, int n) {
platform::dynload::cblas_dcopy(n, x, 1, y, 1);
}
template <>
void VAXPY<float>(float a, const float* x, float* y, int n) {
platform::dynload::cblas_saxpy(n, a, x, 1, y, 1);
}
template <>
void VAXPY<double>(double a, const double* x, double* y, int n) {
platform::dynload::cblas_daxpy(n, a, x, 1, y, 1);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
......@@ -103,6 +123,16 @@ bool VTanhKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \
......@@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
#undef REGISTER_MKL_KERNEL
......@@ -14,6 +14,7 @@
#pragma once
#include <cmath>
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
......@@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n);
template <typename T>
void VExp(const T* x, T* y, int n);
template <typename T>
void VCopy(const T* x, T* y, int n);
template <typename T>
void VAXPY(T a, const T* x, T* y, int n);
template <typename T>
void VSigmoid(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN;
......@@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) {
}
}
template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
VCopy<T>(x, y, attr->w);
for (int h = 1; h != attr->h; ++h) {
VAXPY<T>(static_cast<T>(1), x + h * attr->w, y, attr->w);
}
if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) {
T scalar = static_cast<T>(1);
if (attr->type == SeqPoolType::kAvg) {
scalar = scalar / static_cast<T>(attr->h);
} else {
scalar = scalar / std::sqrt(static_cast<T>(attr->h));
}
VScal<T>(&scalar, y, y, attr->w);
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \
......@@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
#undef DECLARE_MKL_KERNEL
} // namespace mkl
......
......@@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2)
USE_JITKERNEL_REFER(kCRFDecoding)
USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool)
......@@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
#undef REGISTER_REFER_KERNEL
......@@ -332,6 +332,28 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
}
}
template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
for (int w = 0; w < attr->w; ++w) {
const T* src = x + w;
T* dst = y + w;
*dst = static_cast<T>(0);
for (int h = 0; h < attr->h; ++h) {
*dst = *dst + *src;
src += attr->w;
}
}
if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) {
T scalar = static_cast<T>(1);
if (attr->type == SeqPoolType::kAvg) {
scalar = scalar / static_cast<T>(attr->h);
} else {
scalar = scalar / std::sqrt(static_cast<T>(attr->h));
}
VScal<T>(&scalar, y, y, attr->w);
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
......@@ -370,6 +392,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer
......
......@@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
}
};
template <typename T>
struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref,
const typename jit::SeqPoolTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(x.size() % yref.size(), 0);
int w = yref.size();
std::vector<T> y(w);
const T* x_data = x.data();
const T* yref_data = yref.data();
T* y_data = y.data();
tgt(x_data, y_data, &attr);
ExpectEQ<T>(y_data, yref_data, w);
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
......@@ -415,6 +433,31 @@ void TestGRUKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) {
for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
attr.h = h;
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(h * w), yref(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* yref_data = yref.data();
ref(x_data, yref_data, &attr);
VLOG(10) << attr;
TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>,
std::vector<T>>(attr, x, yref, attr);
}
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
......@@ -569,6 +612,12 @@ TEST(JITKernel, kGRUHtPart2) {
TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kSeqPool) {
namespace jit = paddle::operators::jit;
TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) {
namespace jit = paddle::operators::jit;
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
......
......@@ -51,7 +51,7 @@ math_library(pooling)
math_library(selected_rows_functor DEPS selected_rows math_function blas)
math_library(sequence2batch)
math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function)
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale)
math_library(softmax DEPS math_function)
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
......@@ -239,15 +240,33 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
last_pool(context, input, output);
return;
}
if (pooltype == "FIRST") {
math::FirstSeqPoolFunctor<T> first_pool;
first_pool(context, input, output);
return;
}
auto lod = input.lod()[0];
if (pooltype == "SUM") {
auto place = context.GetPlace();
PADDLE_ENFORCE(platform::is_cpu_place(place));
const T* src = input.data<T>();
T* dst = output->mutable_data<T>(place);
jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum);
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr);
dst += attr.w;
src += attr.h * attr.w;
}
return;
}
auto& place = *context.eigen_device();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
Tensor in_t =
input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
......@@ -258,15 +277,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
auto out_e = EigenVector<T>::Flatten(out_t);
if (pooltype == "AVERAGE") {
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SUM") {
if (h > 0) {
const T* in_data = in_t.data<T>();
T* out_data = out_t.mutable_data<T>(context.GetPlace());
blas.VCOPY(w, in_data, out_data);
for (int64_t r = 1; r != h; ++r) {
blas.AXPY(w, 1., in_data + r * w, out_data);
}
}
} else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册