提交 142bb417 编写于 作者: T tensor-tang

add seqpool jitkernel test and benchmark

上级 e58a569c
......@@ -190,6 +190,24 @@ void BenchGRUKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) {
for (int h : TestSizes()) {
for (int w : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type);
std::vector<T> x(h * w), y(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* y_data = y.data();
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
y_data, &attr);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
......@@ -228,4 +246,7 @@ int main(int argc, char* argv[]) {
BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
// seq pool function
BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
}
......@@ -26,6 +26,7 @@ namespace jit {
const char* to_string(KernelType kt) {
switch (kt) {
ONE_CASE(kNone);
ONE_CASE(kVMul);
ONE_CASE(kVAdd);
ONE_CASE(kVAddRelu);
......@@ -45,12 +46,26 @@ const char* to_string(KernelType kt) {
ONE_CASE(kCRFDecoding);
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
}
return nullptr;
}
const char* to_string(SeqPoolType tp) {
switch (tp) {
ONE_CASE(kNonePoolType);
ONE_CASE(kSum);
ONE_CASE(kAvg);
ONE_CASE(kSqrt);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", tp);
return "NOT PoolType";
}
return nullptr;
}
#undef ONE_CASE
KernelType to_kerneltype(const std::string& act) {
......
......@@ -119,6 +119,7 @@ typename KernelTuples::func_type Get(
}
const char* to_string(KernelType kt);
const char* to_string(SeqPoolType kt);
KernelType to_kerneltype(const std::string& act);
......@@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
<< "],act_cand[" << to_string(attr.act_cand) << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) {
os << "height_size[" << attr.h << "],width_size[" << attr.w << "],pool_type["
<< to_string(attr.type) << "]";
return os;
}
} // namespace jit
} // namespace operators
......
......@@ -44,6 +44,13 @@ typedef enum {
kSeqPool,
} KernelType;
typedef enum {
kNonePoolType = 0,
kSum,
kAvg,
kSqrt,
} SeqPoolType;
template <typename T>
struct XYZNTuples {
typedef T data_type;
......@@ -113,16 +120,12 @@ struct GRUTuples {
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
typedef enum {
non = 0,
sum,
avg,
sqrt,
} SeqPoolType;
typedef struct {
typedef struct seq_pool_attr_s {
int h, w;
SeqPoolType type;
seq_pool_attr_s() = default;
explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type)
: h(height), w(width), type(pool_type) {}
} seq_pool_attr_t;
template <typename T>
......
......@@ -334,7 +334,7 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet");
PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet");
for (int w = 0; w < attr->w; ++w) {
const T* src = x + w;
T* dst = y + w;
......
......@@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
}
};
template <typename T>
struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref,
const typename jit::SeqPoolTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(x.size() % yref.size(), 0);
int w = yref.size();
std::vector<T> y(w);
const T* x_data = x.data();
const T* yref_data = yref.data();
T* y_data = y.data();
tgt(x_data, y_data, &attr);
ExpectEQ<T>(y_data, yref_data, w);
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
......@@ -415,6 +433,30 @@ void TestGRUKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
// TODO(TJ): support more
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
for (auto type : pool_types) {
for (int h : TestSizes()) {
for (int w : TestSizes()) {
const jit::seq_pool_attr_t attr(h, w, type);
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(h * w), yref(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* yref_data = yref.data();
ref(x_data, yref_data, &attr);
VLOG(10) << attr;
TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>,
std::vector<T>>(attr, x, yref, attr);
}
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
......@@ -569,6 +611,12 @@ TEST(JITKernel, kGRUHtPart2) {
TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kSeqPool) {
namespace jit = paddle::operators::jit;
TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) {
namespace jit = paddle::operators::jit;
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
......
......@@ -254,7 +254,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
T* dst = output->mutable_data<T>(place);
jit::seq_pool_attr_t attr;
attr.w = input.numel() / input.dims()[0];
attr.type = jit::SeqPoolType::sum;
attr.type = jit::SeqPoolType::kSum;
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册