提交 142bb417 编写于 作者: T tensor-tang

add seqpool jitkernel test and benchmark

上级 e58a569c
...@@ -190,6 +190,24 @@ void BenchGRUKernel() { ...@@ -190,6 +190,24 @@ void BenchGRUKernel() {
} }
} }
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) {
for (int h : TestSizes()) {
for (int w : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type);
std::vector<T> x(h * w), y(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* y_data = y.data();
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
y_data, &attr);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer. // Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...] // To use this tool, run command: ./benchmark [options...]
// Options: // Options:
...@@ -228,4 +246,7 @@ int main(int argc, char* argv[]) { ...@@ -228,4 +246,7 @@ int main(int argc, char* argv[]) {
BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>(); BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>(); BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
// seq pool function
BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
} }
...@@ -26,6 +26,7 @@ namespace jit { ...@@ -26,6 +26,7 @@ namespace jit {
const char* to_string(KernelType kt) { const char* to_string(KernelType kt) {
switch (kt) { switch (kt) {
ONE_CASE(kNone);
ONE_CASE(kVMul); ONE_CASE(kVMul);
ONE_CASE(kVAdd); ONE_CASE(kVAdd);
ONE_CASE(kVAddRelu); ONE_CASE(kVAddRelu);
...@@ -45,12 +46,26 @@ const char* to_string(KernelType kt) { ...@@ -45,12 +46,26 @@ const char* to_string(KernelType kt) {
ONE_CASE(kCRFDecoding); ONE_CASE(kCRFDecoding);
ONE_CASE(kLayerNorm); ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC); ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
default: default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt); PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel"; return "NOT JITKernel";
} }
return nullptr; return nullptr;
} }
const char* to_string(SeqPoolType tp) {
switch (tp) {
ONE_CASE(kNonePoolType);
ONE_CASE(kSum);
ONE_CASE(kAvg);
ONE_CASE(kSqrt);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", tp);
return "NOT PoolType";
}
return nullptr;
}
#undef ONE_CASE #undef ONE_CASE
KernelType to_kerneltype(const std::string& act) { KernelType to_kerneltype(const std::string& act) {
......
...@@ -119,6 +119,7 @@ typename KernelTuples::func_type Get( ...@@ -119,6 +119,7 @@ typename KernelTuples::func_type Get(
} }
const char* to_string(KernelType kt); const char* to_string(KernelType kt);
const char* to_string(SeqPoolType kt);
KernelType to_kerneltype(const std::string& act); KernelType to_kerneltype(const std::string& act);
...@@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) { ...@@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
<< "],act_cand[" << to_string(attr.act_cand) << "]"; << "],act_cand[" << to_string(attr.act_cand) << "]";
return os; return os;
} }
inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) {
os << "height_size[" << attr.h << "],width_size[" << attr.w << "],pool_type["
<< to_string(attr.type) << "]";
return os;
}
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
......
...@@ -44,6 +44,13 @@ typedef enum { ...@@ -44,6 +44,13 @@ typedef enum {
kSeqPool, kSeqPool,
} KernelType; } KernelType;
typedef enum {
kNonePoolType = 0,
kSum,
kAvg,
kSqrt,
} SeqPoolType;
template <typename T> template <typename T>
struct XYZNTuples { struct XYZNTuples {
typedef T data_type; typedef T data_type;
...@@ -113,16 +120,12 @@ struct GRUTuples { ...@@ -113,16 +120,12 @@ struct GRUTuples {
typedef void (*func_type)(gru_t*, const gru_attr_t*); typedef void (*func_type)(gru_t*, const gru_attr_t*);
}; };
typedef enum { typedef struct seq_pool_attr_s {
non = 0,
sum,
avg,
sqrt,
} SeqPoolType;
typedef struct {
int h, w; int h, w;
SeqPoolType type; SeqPoolType type;
seq_pool_attr_s() = default;
explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type)
: h(height), w(width), type(pool_type) {}
} seq_pool_attr_t; } seq_pool_attr_t;
template <typename T> template <typename T>
......
...@@ -334,7 +334,7 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { ...@@ -334,7 +334,7 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
template <typename T> template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet"); PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet");
for (int w = 0; w < attr->w; ++w) { for (int w = 0; w < attr->w; ++w) {
const T* src = x + w; const T* src = x + w;
T* dst = y + w; T* dst = y + w;
......
...@@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>, ...@@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
} }
}; };
template <typename T>
struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref,
const typename jit::SeqPoolTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(x.size() % yref.size(), 0);
int w = yref.size();
std::vector<T> y(w);
const T* x_data = x.data();
const T* yref_data = yref.data();
T* y_data = y.data();
tgt(x_data, y_data, &attr);
ExpectEQ<T>(y_data, yref_data, w);
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples, template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args> typename PlaceType, typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
...@@ -415,6 +433,30 @@ void TestGRUKernel() { ...@@ -415,6 +433,30 @@ void TestGRUKernel() {
} }
} }
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
// TODO(TJ): support more
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) {
for (int h : TestSizes()) {
for (int w : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type);
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(h * w), yref(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* yref_data = yref.data();
ref(x_data, yref_data, &attr);
VLOG(10) << attr;
TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>,
std::vector<T>>(attr, x, yref, attr);
}
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() { void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
...@@ -569,6 +611,12 @@ TEST(JITKernel, kGRUHtPart2) { ...@@ -569,6 +611,12 @@ TEST(JITKernel, kGRUHtPart2) {
TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>(); TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
} }
TEST(JITKernel, kSeqPool) {
namespace jit = paddle::operators::jit;
TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) { TEST(JITKernel, kNCHW16CMulNC) {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
......
...@@ -254,7 +254,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -254,7 +254,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
T* dst = output->mutable_data<T>(place); T* dst = output->mutable_data<T>(place);
jit::seq_pool_attr_t attr; jit::seq_pool_attr_t attr;
attr.w = input.numel() / input.dims()[0]; attr.w = input.numel() / input.dims()[0];
attr.type = jit::SeqPoolType::sum; attr.type = jit::SeqPoolType::kSum;
auto seqpool = auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr); attr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册