提交 15da2f9a 编写于 作者: T tensor-tang

add embseqpool jitkernel refer code, test and benchmark

test=develop
上级 c2ccf145
...@@ -301,6 +301,37 @@ void BenchSeqPoolKernel() { ...@@ -301,6 +301,37 @@ void BenchSeqPoolKernel() {
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchEmbSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
int64_t tbl_h = 1e4;
for (int tbl_w : {10, 16, 256}) {
Tensor table;
table.Resize({tbl_h, tbl_w});
RandomVec<T>(tbl_h * tbl_w, table.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* table_data = table.data<T>();
for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 10, 16}) {
int64_t out_w = tbl_w * idx_w;
jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w,
type);
Tensor idx, out;
idx.Resize({idx_h, idx_w});
out.Resize({out_w});
RandomVec<int64_t>(idx_h * idx_w,
idx.mutable_data<int64_t>(PlaceType()), 0,
tbl_h - 1);
const int64_t* idx_data = idx.data<int64_t>();
T* o_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType>(
attr, table_data, idx_data, o_data, &attr);
}
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() { void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) { for (int m : {1, 2, 3, 4}) {
...@@ -376,6 +407,11 @@ BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); } ...@@ -376,6 +407,11 @@ BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
// seq pool function // seq pool function
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); } BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
// embedding seq pool function
// Benchmark entry for kEmbSeqPool: the body simply forwards to
// BenchEmbSeqPoolKernel on CPUPlace.
// NOTE(review): `T` is not declared here; presumably BENCH_FP32_CPU supplies
// it as float - confirm against the macro definition.
BENCH_FP32_CPU(kEmbSeqPool) {
BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
}
// matmul // matmul
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); } BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
......
...@@ -54,6 +54,7 @@ const char* to_string(KernelType kt) { ...@@ -54,6 +54,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kHMax); ONE_CASE(kHMax);
ONE_CASE(kHSum); ONE_CASE(kHSum);
ONE_CASE(kSoftmax); ONE_CASE(kSoftmax);
ONE_CASE(kEmbSeqPool);
default: default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt); PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel"; return "NOT JITKernel";
......
...@@ -172,6 +172,15 @@ inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) { ...@@ -172,6 +172,15 @@ inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) {
return os; return os;
} }
// Streams an emb_seq_pool_attr_t as a comma-separated list of `name[value]`
// fields (table/index dimensions, output width, pooling type) and returns the
// stream so insertions can be chained.
inline std::ostream& operator<<(std::ostream& os,
                                const emb_seq_pool_attr_t& attr) {
  // One statement per field; each literal also closes the previous bracket.
  os << "table_height[" << attr.table_height;
  os << "],table_width[" << attr.table_width;
  os << "],index_height[" << attr.index_height;
  os << "],index_width[" << attr.index_width;
  os << "],output_width[" << attr.out_width;
  os << "],pool_type[" << to_string(attr.pool_type) << "]";
  return os;
}
inline std::ostream& operator<<(std::ostream& os, const matmul_attr_t& attr) { inline std::ostream& operator<<(std::ostream& os, const matmul_attr_t& attr) {
os << "M[" << attr.m << "],N[" << attr.n << "],K[" << attr.k << "]"; os << "M[" << attr.m << "],N[" << attr.n << "],K[" << attr.k << "]";
return os; return os;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#pragma once #pragma once
#include <cstdint>
#include "paddle/fluid/operators/jit/macro.h" #include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -20,34 +21,35 @@ namespace paddle { ...@@ -20,34 +21,35 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
// TODO(TJ): reorder by alphabet
typedef enum { typedef enum {
kNone = 0, kNone = 0,
kVMul = 1, // sort by alphabet
kVAdd = 2, kCRFDecoding = 1,
kVAddRelu, kEmbSeqPool = 2,
kVSub,
kVScal,
kVAddBias,
kVRelu,
kVIdentity,
kVSquare,
kVExp,
kVSigmoid,
kVTanh,
kLSTMCtHt,
kLSTMC1H1,
kGRUH1, kGRUH1,
kGRUHtPart1, kGRUHtPart1,
kGRUHtPart2, kGRUHtPart2,
kCRFDecoding, kHSum, // horizontal sum
kHMax, // horizontal max
kLSTMCtHt,
kLSTMC1H1,
kLayerNorm, kLayerNorm,
kMatMul,
kNCHW16CMulNC, kNCHW16CMulNC,
kSeqPool, kSeqPool,
kMatMul,
kHSum, // horizontal max
kHMax, // horizontal sum
kSoftmax, kSoftmax,
kVAdd,
kVAddBias,
kVAddRelu,
kVExp,
kVIdentity,
kVMul,
kVRelu,
kVScal,
kVSigmoid,
kVSquare,
kVSub,
kVTanh,
} KernelType; } KernelType;
typedef enum { typedef enum {
...@@ -145,6 +147,32 @@ struct SeqPoolTuples { ...@@ -145,6 +147,32 @@ struct SeqPoolTuples {
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
}; };
// Attributes describing one embedding + sequence-pool problem:
//   table_height / table_width : dimensions of the embedding table
//   index_height / index_width : dimensions of the index matrix
//   out_width                  : length of the output vector
//   pool_type                  : pooling mode (SeqPoolType::kSum by default)
typedef struct emb_seq_pool_attr_s {
  // Every member carries a default initializer so that a default-constructed
  // attribute holds well-defined values instead of indeterminate ones; the
  // pool_type default mirrors the default argument of the constructor below.
  int64_t table_height{0};
  int64_t table_width{0};
  int64_t index_height{0};
  int64_t index_width{0};
  int64_t out_width{0};
  SeqPoolType pool_type{SeqPoolType::kSum};

  emb_seq_pool_attr_s() = default;

  // Builds a fully specified attribute; only the pooling type is optional.
  explicit emb_seq_pool_attr_s(int64_t tbl_height, int64_t tbl_width,
                               int64_t idx_height, int64_t idx_width,
                               int64_t output_width,
                               SeqPoolType seqpool_type = SeqPoolType::kSum)
      : table_height(tbl_height),
        table_width(tbl_width),
        index_height(idx_height),
        index_width(idx_width),
        out_width(output_width),
        pool_type(seqpool_type) {}
} emb_seq_pool_attr_t;
// Type bundle for the embedding + sequence-pool kernel:
//   data_type : element type of the table and of the output
//   attr_type : attribute struct passed to the kernel
//   func_type : kernel signature (table, int64 indices, output, attributes)
template <typename T>
struct EmbSeqPoolTuples {
  using data_type = T;
  using attr_type = emb_seq_pool_attr_t;
  using func_type = void (*)(const T*, const int64_t*, T*,
                             const emb_seq_pool_attr_t*);
};
typedef struct matmul_attr_s { typedef struct matmul_attr_s {
int m, n, k; int m, n, k;
void* packed_weight{nullptr}; void* packed_weight{nullptr};
......
...@@ -56,6 +56,11 @@ size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) { ...@@ -56,6 +56,11 @@ size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k; return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k;
} }
// JitCodeKey specialization for emb_seq_pool_attr_t: the key is the table
// width alone; no other attribute field contributes to it.
template <>
size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
  const size_t key = static_cast<size_t>(attr.table_width);
  return key;
}
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -32,3 +32,4 @@ USE_JITKERNEL_REFER(kVSquare) ...@@ -32,3 +32,4 @@ USE_JITKERNEL_REFER(kVSquare)
USE_JITKERNEL_REFER(kHSum) USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax) USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kSoftmax) USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
...@@ -57,4 +57,6 @@ REGISTER_REFER_KERNEL(kHSum, HSum); ...@@ -57,4 +57,6 @@ REGISTER_REFER_KERNEL(kHSum, HSum);
REGISTER_REFER_KERNEL(kSoftmax, Softmax); REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
#undef REGISTER_REFER_KERNEL #undef REGISTER_REFER_KERNEL
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <cmath> #include <cmath>
#include <limits> #include <limits>
#include <string>
#include "paddle/fluid/operators/jit/helper.h" #include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -414,6 +415,37 @@ void Softmax(const T* x, T* y, int n, int bs = 1) { ...@@ -414,6 +415,37 @@ void Softmax(const T* x, T* y, int n, int bs = 1) {
} }
} }
// embedding seq pool (reference implementation)
//
// table : row-major matrix of shape (table_height, table_width)
// idx   : row-major matrix of shape (index_height, index_width); each entry
//         selects a row of `table`
// out   : vector of length table_width * index_width, i.e. one segment of
//         table_width values per column of idx
//
// For every column of idx, the table row named by the first idx row is copied
// into that column's output segment; the table rows named by the remaining
// idx rows are then combined into the segment with VAdd. attr->pool_type is
// not consulted here.
//
// Each index is enforced to lie in [0, table_height) before it is used, and
// out_width is enforced to equal table_width * index_width.
//
// NOTE(review): the first idx row is read regardless of index_height, so
// callers appear to need index_height >= 1 - confirm.
template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
                const emb_seq_pool_attr_t* attr) {
  PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);

  // Range check for the idx entry at flat position i.
  const auto enforce_valid_row = [&](int64_t i) {
    PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d",
                      idx[i], i);
    PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
  };

  const int64_t row_len = attr->table_width;
  const int64_t n_cols = attr->index_width;
  const int64_t n_rows = attr->index_height;

  // Row 0 of idx initialises every output segment with a plain copy.
  for (int64_t col = 0; col != n_cols; ++col) {
    enforce_valid_row(col);
    T* segment = out + col * row_len;
    const T* src_row = table + idx[col] * row_len;
    std::memcpy(segment, src_row, row_len * sizeof(T));
  }

  // Rows 1.. of idx are folded into the matching segment in place. `flat`
  // walks idx in row-major order, starting at the first entry of row 1.
  int64_t flat = n_cols;
  for (int64_t row = 1; row < n_rows; ++row) {
    for (int64_t col = 0; col < n_cols; ++col, ++flat) {
      enforce_valid_row(flat);
      T* segment = out + col * row_len;
      VAdd(table + idx[flat] * row_len, segment, segment, row_len);
    }
  }
}
#define DECLARE_REFER_KERNEL(name, tuples) \ #define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \ template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \ class name##Kernel : public ReferKernel<tuples<T>> { \
...@@ -462,6 +494,8 @@ DECLARE_REFER_KERNEL(HSum, XRNTuples); ...@@ -462,6 +494,8 @@ DECLARE_REFER_KERNEL(HSum, XRNTuples);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples); DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
#undef DECLARE_REFER_KERNEL #undef DECLARE_REFER_KERNEL
} // namespace refer } // namespace refer
......
...@@ -270,6 +270,32 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>, std::vector<T>, ...@@ -270,6 +270,32 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>, std::vector<T>,
} }
}; };
// TestFuncWithRefer specialization for the embedding + sequence-pool kernel.
//
// Checks that the input sizes agree with `attr`, runs the implementation under
// test (`tgt`) into a fresh buffer, and compares that buffer against the
// precomputed reference output `oref` with ExpectEQ.
template <typename T>
struct TestFuncWithRefer<jit::EmbSeqPoolTuples<T>, std::vector<T>,
                         std::vector<int64_t>, std::vector<T>,
                         typename jit::EmbSeqPoolTuples<T>::attr_type> {
  void operator()(const typename jit::EmbSeqPoolTuples<T>::func_type tgt,
                  const std::vector<T>& table, const std::vector<int64_t>& idx,
                  const std::vector<T>& oref,
                  const typename jit::EmbSeqPoolTuples<T>::attr_type& attr) {
    // Sanity checks on the kernel pointer and on the shapes of the inputs.
    EXPECT_TRUE(tgt != nullptr);
    EXPECT_EQ(table.size(),
              static_cast<size_t>(attr.table_height * attr.table_width));
    EXPECT_EQ(idx.size(),
              static_cast<size_t>(attr.index_height * attr.index_width));
    EXPECT_EQ(oref.size(),
              static_cast<size_t>(attr.table_width * attr.index_width));

    // The output buffer has exactly as many elements as the reference.
    const int o_w = oref.size();
    std::vector<T> out(o_w);

    tgt(table.data(), idx.data(), out.data(), &attr);
    ExpectEQ<T>(out.data(), oref.data(), o_w);
  }
};
template <typename T> template <typename T>
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>, struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>, std::vector<T>,
...@@ -587,6 +613,40 @@ void TestSoftmaxKernel() { ...@@ -587,6 +613,40 @@ void TestSoftmaxKernel() {
} }
} }
// Tests every implementation of the embedding + sequence-pool kernel KT
// against the reference implementation.
//
// A 10000-row table is generated for each width returned by TestSizes(); for
// each index shape (height and width each from {1, 2, 10, 16}) the reference
// kernel produces the expected output, which TestAllImpls then uses to check
// the other implementations.
template <jit::KernelType KT, typename T, typename PlaceType>
void TestEmbSeqPoolKernel() {
  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);

  const int64_t table_rows = 10000;
  // only support sum yet
  const std::vector<jit::SeqPoolType> pooling_modes = {jit::SeqPoolType::kSum};

  for (int table_cols : TestSizes()) {
    std::vector<T> table(table_rows * table_cols);
    RandomVec<T>(table_rows * table_cols, table.data(), -2.f, 2.f);

    for (const auto mode : pooling_modes) {
      for (int index_cols : {1, 2, 10, 16}) {
        for (int index_rows : {1, 2, 10, 16}) {
          auto ref = jit::GetRefer<KT, jit::EmbSeqPoolTuples<T>>();
          EXPECT_TRUE(ref != nullptr);

          // RandomVec is given 0 and table_rows - 1 as the bounds for the
          // generated row indices.
          std::vector<int64_t> idx(index_rows * index_cols);
          RandomVec<int64_t>(index_rows * index_cols, idx.data(), 0,
                             table_rows - 1);

          // Expected output, filled in by the reference kernel.
          const int64_t out_len = table_cols * index_cols;
          std::vector<T> oref(out_len);
          jit::emb_seq_pool_attr_t attr(table_rows, table_cols, index_rows,
                                        index_cols, out_len, mode);
          ref(table.data(), idx.data(), oref.data(), &attr);

          TestAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType, std::vector<T>,
                       std::vector<int64_t>, std::vector<T>>(attr, table, idx,
                                                             oref, attr);
        }
      }
    }
  }
}
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() { void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
...@@ -756,6 +816,11 @@ TEST(JITKernel, kSoftmax) { ...@@ -756,6 +816,11 @@ TEST(JITKernel, kSoftmax) {
TestSoftmaxKernel<jit::kSoftmax, double, CPUPlace>(); TestSoftmaxKernel<jit::kSoftmax, double, CPUPlace>();
} }
// Runs the embedding + sequence-pool kernel test on CPUPlace for both float
// and double element types.
TEST(JITKernel, kEmbSeqPool) {
TestEmbSeqPoolKernel<jit::kEmbSeqPool, float, CPUPlace>();
TestEmbSeqPoolKernel<jit::kEmbSeqPool, double, CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) { TEST(JITKernel, kNCHW16CMulNC) {
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, CPUPlace>(); TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, CPUPlace>();
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double, CPUPlace>(); TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double, CPUPlace>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册