Commit 14a764c9 authored by tensor-tang

simplify the jitkernel templates and tests

test=develop
Parent 802f362a
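For orientation, the core of the template simplification: every kernel tuple now carries its `KernelType` as a `static constexpr` member, so the enum value no longer has to be repeated as a separate template argument, and the tuple types drop the plural `Tuples` suffix. A minimal before/after sketch of one call site (taken from the crf_decoding_op.h hunk below; `tag_num` stands for that kernel's attribute):

```cpp
// Before this commit: KernelType, tuple type and place were all template arguments.
auto ker_old = jit::KernelFuncs<jit::kCRFDecoding, jit::CRFDecodingTuples<float>,
                                platform::CPUPlace>::Cache()
                   .At(tag_num);

// After: the tuple identifies the kernel itself
// (jit::CRFDecodingTuple<float>::kernel_type == jit::kCRFDecoding),
// so only the tuple and the place appear at the call site.
auto ker_new =
    jit::KernelFuncs<jit::CRFDecodingTuple<float>, platform::CPUPlace>::Cache()
        .At(tag_num);
```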
...@@ -82,9 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> { ...@@ -82,9 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track; Tensor track;
int* track_value = int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace()); track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::KernelFuncs<jit::kCRFDecoding, jit::CRFDecodingTuples<T>, auto ker =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::CRFDecodingTuple<T>, platform::CPUPlace>::Cache()
.At(tag_num); .At(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num); ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max(); T max_score = -std::numeric_limits<T>::max();
int max_i = 0; int max_i = 0;
......
...@@ -110,10 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -110,10 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16; constexpr int simd_width = 16;
int C = c / simd_width; int C = c / simd_width;
auto multiply = auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
jit::KernelFuncs<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache() .At(0);
.At(0);
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) { for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) { for (int ci = 0; ci < C; ci++) {
......
...@@ -53,8 +53,7 @@ struct EmbeddingVSumFunctor { ...@@ -53,8 +53,7 @@ struct EmbeddingVSumFunctor {
for (size_t i = 0; i != ids_lod.size() - 1; ++i) { for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
attr.index_height = ids_lod[i + 1] - ids_lod[i]; attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool = auto emb_seqpool =
jit::KernelFuncs<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>, jit::KernelFuncs<jit::EmbSeqPoolTuple<T>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache()
.At(attr); .At(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width, emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr); &attr);
...@@ -138,8 +137,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -138,8 +137,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
const T *d_output_data = d_output->data<T>(); const T *d_output_data = d_output->data<T>();
auto vbroadcast = auto vbroadcast =
jit::KernelFuncs<jit::kVBroadcast, jit::VBroadcastTuples<T>, jit::KernelFuncs<jit::VBroadcastTuple<T>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache()
.At(out_width); .At(out_width);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]); int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
......
...@@ -182,32 +182,32 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -182,32 +182,32 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \ const int total_T = x_dims[0]; \
const int D3 = wh_dims[1] const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \ auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \ auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \ bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \ const int M = x_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D2 = D * 2; \ const int D2 = D * 2; \
const jit::gru_attr_t attr( \ const jit::gru_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \ jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \ jit::gru_t one_step; \
auto ComputeH1 = jit::KernelFuncs<jit::kGRUH1, jit::GRUTuples<T>, \ auto ComputeH1 = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At( \
.At(attr); \ attr); \
auto ComputeHtPart1 = jit::KernelFuncs<jit::kGRUHtPart1, jit::GRUTuples<T>, \ auto ComputeHtPart1 = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
.At(attr); \ .At(attr); \
auto ComputeHtPart2 = jit::KernelFuncs<jit::kGRUHtPart2, jit::GRUTuples<T>, \ auto ComputeHtPart2 = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
.At(attr); \ .At(attr); \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \ auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place) T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
......
...@@ -235,34 +235,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -235,34 +235,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D4 = wh_dims[1] const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \ /* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \ const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \ /* for peephole only*/ \
T* checked_cell_data = nullptr; \ T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \ auto place = ctx.GetPlace(); \
if (use_peepholes) { \ if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} \ } \
const jit::lstm_attr_t attr( \ const jit::lstm_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \ jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \ jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
use_peepholes); \ use_peepholes); \
jit::lstm_t one_step; \ jit::lstm_t one_step; \
one_step.wp = wp_data; \ one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \ one_step.checked = checked_cell_data; \
auto ComputeC1H1 = jit::KernelFuncs<jit::kLSTMC1H1, jit::LSTMTuples<T>, \ auto ComputeC1H1 = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
.At(attr); \ attr); \
auto ComputeCtHt = jit::KernelFuncs<jit::kLSTMCtHt, jit::LSTMTuples<T>, \ auto ComputeCtHt = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
.At(attr) attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
......
...@@ -81,12 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() { ...@@ -81,12 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() {
template <typename T> template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y, static void fc_relu(const T* x, const T* w, const T* b, T* y,
const jit::matmul_attr_t& attr) { const jit::matmul_attr_t& attr) {
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>, auto matmul =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
auto addbias_relu = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>, auto addbias_relu =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr.n); attr.n);
matmul(x, w, y, &attr); matmul(x, w, y, &attr);
T* dst = y; T* dst = y;
for (int i = 0; i < attr.m; ++i) { for (int i = 0; i < attr.m; ++i) {
......
...@@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> { ...@@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
attr.type = jit::SeqPoolType::kSqrt; attr.type = jit::SeqPoolType::kSqrt;
} }
auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>, auto seqpool =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
size_t n = ins.size(); size_t n = ins.size();
size_t dst_step_size = n * w; size_t dst_step_size = n * w;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
......
...@@ -93,24 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> { ...@@ -93,24 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
attr.n = y_dims[1]; attr.n = y_dims[1];
int o_numel = attr.m * attr.n; int o_numel = attr.m * attr.n;
auto vsquare_x = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>, auto vsquare_x =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr.m * attr.k); attr.m * attr.k);
auto vsquare_y = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>, auto vsquare_y =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr.k * attr.n); attr.k * attr.n);
auto vsquare_xy = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>, auto vsquare_xy =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
.At(o_numel); o_numel);
auto vsub = jit::KernelFuncs<jit::kVSub, jit::XYZNTuples<T>, auto vsub =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VSubTuple<T>, platform::CPUPlace>::Cache().At(
.At(o_numel); o_numel);
auto vscal = jit::KernelFuncs<jit::kVScal, jit::AXYNTuples<T>, auto vscal =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VScalTuple<T>, platform::CPUPlace>::Cache().At(
.At(o_numel); o_numel);
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>, auto matmul =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* y_data = y->data<T>(); const T* y_data = y->data<T>();
......
...@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) { ...@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) {
InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \ InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \
void BenchJITKernel_##name##_##dtype##_##place##_::Run() void BenchJITKernel_##name##_##dtype##_##place##_::Run()
#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU)
void RUN_ALL_BENCHMARK() { void RUN_ALL_BENCHMARK() {
for (auto p : g_all_benchmarks) { for (auto p : g_all_benchmarks) {
if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) { if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) {
...@@ -90,11 +88,11 @@ std::vector<int> TestSizes() { ...@@ -90,11 +88,11 @@ std::vector<int> TestSizes() {
return s; return s;
} }
template <typename KernelTuples, typename... Args> template <typename KernelTuple, typename... Args>
struct BenchFunc { struct BenchFunc {
// return this function avg time // return this function avg time
// TODO(TJ): clear cache every time // TODO(TJ): clear cache every time
double operator()(const typename KernelTuples::func_type tgt, Args... args) { double operator()(const typename KernelTuple::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) { for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...); tgt(args...);
} }
...@@ -109,31 +107,30 @@ struct BenchFunc { ...@@ -109,31 +107,30 @@ struct BenchFunc {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
template <jit::KernelType KT, typename KernelTuples, typename PlaceType, template <typename KernelTuple, typename PlaceType, typename... Args>
typename... Args> void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { BenchFunc<KernelTuple, Args...> benchmark;
BenchFunc<KernelTuples, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos; std::vector<std::pair<std::string, double>> infos;
// test refer // test refer
auto refer = jit::GetRefer<KT, KernelTuples>(); auto refer = jit::GetRefer<KernelTuple>();
if (!refer) { if (!refer) {
LOG(FATAL) << "Refer can not be empty!"; LOG(FATAL) << "Refer can not be empty!";
} }
infos.push_back(std::make_pair("Refer", benchmark(refer, args...))); infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode // test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr); auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
if (jitcode) { if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...))); infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
} }
// test all impls in more // test all impls in more
jit::KernelKey kkey(KT, PlaceType()); jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels(); auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey); auto iter = pool.find(kkey);
if (iter != pool.end()) { if (iter != pool.end()) {
auto& impls = iter->second; auto& impls = iter->second;
for (auto& impl : impls) { for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get()); auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) { if (i && i->UseMe(attr)) {
auto more = i->GetFunc(); auto more = i->GetFunc();
infos.push_back( infos.push_back(
...@@ -142,7 +139,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -142,7 +139,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
} }
} }
// Test result from Get function // Test result from Get function
auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr); auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
if (!tgt) { if (!tgt) {
LOG(FATAL) << "Target can not be empty!"; LOG(FATAL) << "Target can not be empty!";
} }
...@@ -150,7 +147,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -150,7 +147,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
// print // print
std::ostringstream loginfos; std::ostringstream loginfos;
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": "; loginfos << "Kernel Type " << jit::to_string(KernelTuple::kernel_type) << ": "
<< attr << ": ";
for (auto pair : infos) { for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; "; loginfos << pair.first << " takes " << pair.second << " us; ";
} }
...@@ -159,8 +157,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -159,8 +157,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
using Tensor = paddle::framework::Tensor; using Tensor = paddle::framework::Tensor;
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXYZNKernel() { void BenchKernelXYZN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x, y, z; Tensor x, y, z;
x.Resize({d}); x.Resize({d});
...@@ -171,16 +170,16 @@ void BenchXYZNKernel() { ...@@ -171,16 +170,16 @@ void BenchXYZNKernel() {
T* z_data = z.mutable_data<T>(PlaceType()); T* z_data = z.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
RandomVec<T>(d, y_data); RandomVec<T>(d, y_data);
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y.data<T>(), z_data,
y.data<T>(), z_data, d); d);
// test inplace // test inplace
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data, BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), z_data, z_data, d);
z_data, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchAXYNKernel() { void BenchKernelAXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
const T a = static_cast<T>(3); const T a = static_cast<T>(3);
Tensor x, y; Tensor x, y;
...@@ -189,26 +188,26 @@ void BenchAXYNKernel() { ...@@ -189,26 +188,26 @@ void BenchAXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType()); T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data, BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), y_data, d);
d);
// test inplace // test inplace
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data, BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), x_data, d);
d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXRNKernel() { void BenchKernelXRN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x; Tensor x;
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType())); RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
T res; T res;
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d); BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), &res, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXYNKernel() { void BenchKernelXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x, y; Tensor x, y;
x.Resize({d}); x.Resize({d});
...@@ -216,12 +215,13 @@ void BenchXYNKernel() { ...@@ -216,12 +215,13 @@ void BenchXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType()); T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data<T>(), y_data, d); BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y_data, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchLSTMKernel() { void BenchKernelLSTM() {
using T = typename KernelTuple::data_type;
for (bool use_peephole : {true, false}) { for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) { for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh, const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
...@@ -252,13 +252,14 @@ void BenchLSTMKernel() { ...@@ -252,13 +252,14 @@ void BenchLSTMKernel() {
step.wp = wp_data; step.wp = wp_data;
step.checked = checked_data; step.checked = checked_data;
} }
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr); BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchGRUKernel() { void BenchKernelGRU() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh); const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
auto place = PlaceType(); auto place = PlaceType();
...@@ -275,12 +276,13 @@ void BenchGRUKernel() { ...@@ -275,12 +276,13 @@ void BenchGRUKernel() {
step.gates = x_data; step.gates = x_data;
step.ht_1 = ht_1_data; step.ht_1 = ht_1_data;
step.ht = ht_data; step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr); BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSeqPoolKernel() { void BenchKernelSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = { std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) { for (auto type : pool_types) {
...@@ -294,15 +296,15 @@ void BenchSeqPoolKernel() { ...@@ -294,15 +296,15 @@ void BenchSeqPoolKernel() {
RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f); RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data, BenchAllImpls<KernelTuple, PlaceType>(attr, x_data, y_data, &attr);
y_data, &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchEmbSeqPoolKernel() { void BenchKernelEmbSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
int64_t tbl_h = 1e4; int64_t tbl_h = 1e4;
for (int tbl_w : {10, 16, 256}) { for (int tbl_w : {10, 16, 256}) {
...@@ -324,16 +326,17 @@ void BenchEmbSeqPoolKernel() { ...@@ -324,16 +326,17 @@ void BenchEmbSeqPoolKernel() {
tbl_h - 1); tbl_h - 1);
const int64_t* idx_data = idx.data<int64_t>(); const int64_t* idx_data = idx.data<int64_t>();
T* o_data = out.mutable_data<T>(PlaceType()); T* o_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(attr, table_data, idx_data,
attr, table_data, idx_data, o_data, &attr); o_data, &attr);
} }
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSgdKernel() { void BenchKernelSgd() {
using T = typename KernelTuple::data_type;
const T lr = 0.1; const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower, auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> { const int64_t upper) -> std::vector<int64_t> {
...@@ -364,15 +367,16 @@ void BenchSgdKernel() { ...@@ -364,15 +367,16 @@ void BenchSgdKernel() {
const T* grad_data = grad.data<T>(); const T* grad_data = grad.data<T>();
const int64_t* rows_data = rows.data(); const int64_t* rows_data = rows.data();
jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size); jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
BenchAllImpls<KT, jit::SgdTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(attr, &lr, param_data, grad_data,
attr, &lr, param_data, grad_data, rows_data, param_data, &attr); rows_data, param_data, &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchMatMulKernel() { void BenchKernelMatMul() {
using T = typename KernelTuple::data_type;
for (int m : {1, 2, 3, 4}) { for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) { for (int n : TestSizes()) {
for (int k : TestSizes()) { for (int k : TestSizes()) {
...@@ -386,15 +390,16 @@ void BenchMatMulKernel() { ...@@ -386,15 +390,16 @@ void BenchMatMulKernel() {
const T* b_data = b.data<T>(); const T* b_data = b.data<T>();
T* c_data = c.mutable_data<T>(PlaceType()); T* c_data = c.mutable_data<T>(PlaceType());
const jit::matmul_attr_t attr{m, n, k}; const jit::matmul_attr_t attr{m, n, k};
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(attr, a_data, b_data, BenchAllImpls<KernelTuple, PlaceType>(attr, a_data, b_data, c_data,
c_data, &attr); &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSoftmaxKernel() { void BenchKernelSoftmax() {
using T = typename KernelTuple::data_type;
for (int bs : {1, 2, 10}) { for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) { for (int n : TestSizes()) {
Tensor x, y; Tensor x, y;
...@@ -403,14 +408,14 @@ void BenchSoftmaxKernel() { ...@@ -403,14 +408,14 @@ void BenchSoftmaxKernel() {
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f); RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n, BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
bs);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchLayerNormKernel() { void BenchKernelLayerNorm() {
using T = typename KernelTuple::data_type;
const T epsilon = 9.99999975e-06; const T epsilon = 9.99999975e-06;
for (int n : {1, 2, 10}) { for (int n : {1, 2, 10}) {
for (int x_dim_0 : {1, 9, 17, 50}) { for (int x_dim_0 : {1, 9, 17, 50}) {
...@@ -439,16 +444,17 @@ void BenchLayerNormKernel() { ...@@ -439,16 +444,17 @@ void BenchLayerNormKernel() {
T* var_data = var.data<T>(); T* var_data = var.data<T>();
T* out_data = out.mutable_data<T>(PlaceType()); T* out_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::LayerNormTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(right, x_data, out_data,
right, x_data, out_data, mean_data, var_data, scale_data, bias_data, mean_data, var_data, scale_data,
left, epsilon, right); bias_data, left, epsilon, right);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchCRFDecodingKernel() { void BenchKernelCRFDecoding() {
using T = typename KernelTuple::data_type;
constexpr int state_trans_base_idx = 2; constexpr int state_trans_base_idx = 2;
for (int seq_len : {1, 11, 17, 50}) { for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : TestSizes()) { for (int tag_num : TestSizes()) {
...@@ -468,14 +474,15 @@ void BenchCRFDecodingKernel() { ...@@ -468,14 +474,15 @@ void BenchCRFDecodingKernel() {
T* alpha_data = alpha.mutable_data<T>(PlaceType()); T* alpha_data = alpha.mutable_data<T>(PlaceType());
int* track_data = track.mutable_data<int>(PlaceType()); int* track_data = track.mutable_data<int>(PlaceType());
BenchAllImpls<KT, jit::CRFDecodingTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(tag_num, seq_len, x_data, w_data,
tag_num, seq_len, x_data, w_data, alpha_data, track_data, tag_num); alpha_data, track_data, tag_num);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchVBroadcastKernel() { void BenchKernelVBroadcast() {
using T = typename KernelTuple::data_type;
for (int64_t w : {1, 16, 64, 100, 256}) { for (int64_t w : {1, 16, 64, 100, 256}) {
Tensor x; Tensor x;
x.Resize({w}); x.Resize({w});
...@@ -485,78 +492,86 @@ void BenchVBroadcastKernel() { ...@@ -485,78 +492,86 @@ void BenchVBroadcastKernel() {
Tensor y; Tensor y;
y.Resize({h * w}); y.Resize({h * w});
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(w, x_data, y_data,
w, x_data, y_data, static_cast<int64_t>(h), w); static_cast<int64_t>(h), w);
} }
} }
} }
using T = float; #define BenchKernelVMul BenchKernelXYZN
using CPUPlace = paddle::platform::CPUPlace; #define BenchKernelVAdd BenchKernelXYZN
#define BenchKernelVAddRelu BenchKernelXYZN
#define BenchKernelVSub BenchKernelXYZN
// xyzn #define BenchKernelVScal BenchKernelAXYN
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); } #define BenchKernelVAddBias BenchKernelAXYN
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
// axyn #define BenchKernelVRelu BenchKernelXYN
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); } #define BenchKernelVIdentity BenchKernelXYN
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); } #define BenchKernelVSquare BenchKernelXYN
#define BenchKernelVExp BenchKernelXYN
#define BenchKernelVSigmoid BenchKernelXYN
#define BenchKernelVTanh BenchKernelXYN
#define BenchKernelVCopy BenchKernelXYN
// xrn #define BenchKernelHMax BenchKernelXRN
BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); } #define BenchKernelHSum BenchKernelXRN
BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
// xyn #define BenchKernelLSTMCtHt BenchKernelLSTM
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); } #define BenchKernelLSTMC1H1 BenchKernelLSTM
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
// lstm and peephole
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
// gru functions
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
// seq pool function
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
// embedding seq pool function
BENCH_FP32_CPU(kEmbSeqPool) {
BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
}
// sgd function #define BenchKernelGRUH1 BenchKernelGRU
BENCH_FP32_CPU(kSgd) { BenchSgdKernel<jit::kSgd, T, CPUPlace>(); } #define BenchKernelGRUHtPart1 BenchKernelGRU
#define BenchKernelGRUHtPart2 BenchKernelGRU
// matmul using CPUPlace = paddle::platform::CPUPlace;
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
// softmax #define BENCH_FP32_CPU(name) \
BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); } BENCH_JITKERNEL(name, FP32, CPU) { \
BenchKernel##name<jit::name##Tuple<float>, CPUPlace>(); \
}
// layernorm // xyzn
BENCH_FP32_CPU(kLayerNorm) { BENCH_FP32_CPU(VMul);
BenchLayerNormKernel<jit::kLayerNorm, T, CPUPlace>(); BENCH_FP32_CPU(VAdd);
} BENCH_FP32_CPU(VAddRelu);
BENCH_FP32_CPU(VSub);
// crfdecoding // axyn
BENCH_FP32_CPU(kCRFDecoding) { BENCH_FP32_CPU(VScal);
BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>(); BENCH_FP32_CPU(VAddBias);
}
// vbroadcast function // xyn
BENCH_FP32_CPU(kVBroadcast) { BENCH_FP32_CPU(VRelu);
BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>(); BENCH_FP32_CPU(VIdentity);
} BENCH_FP32_CPU(VSquare);
BENCH_FP32_CPU(VExp);
BENCH_FP32_CPU(VSigmoid);
BENCH_FP32_CPU(VTanh);
BENCH_FP32_CPU(VCopy);
// xrn
BENCH_FP32_CPU(HMax);
BENCH_FP32_CPU(HSum);
// LSTM
BENCH_FP32_CPU(LSTMCtHt);
BENCH_FP32_CPU(LSTMC1H1);
// GRU
BENCH_FP32_CPU(GRUH1);
BENCH_FP32_CPU(GRUHtPart1);
BENCH_FP32_CPU(GRUHtPart2);
BENCH_FP32_CPU(LayerNorm);
BENCH_FP32_CPU(CRFDecoding);
BENCH_FP32_CPU(SeqPool);
BENCH_FP32_CPU(EmbSeqPool);
BENCH_FP32_CPU(MatMul);
BENCH_FP32_CPU(Softmax);
BENCH_FP32_CPU(Sgd);
BENCH_FP32_CPU(VBroadcast);
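As a reading aid, a hedged sketch of what one of the registrations above expands to; the `BenchJITKernel_*` class and `InsertBenchmark` plumbing come from `BENCH_JITKERNEL`, partially visible in the hunk near the top of this file:

```cpp
// Approximate expansion of BENCH_FP32_CPU(VMul);
void BenchJITKernel_VMul_FP32_CPU_::Run() {
  // BenchKernelVMul is #defined to BenchKernelXYZN above, so every kernel of
  // the same tuple shape reuses one generic benchmark routine, parameterized
  // only by the tuple that names the kernel.
  BenchKernelVMul<jit::VMulTuple<float>, CPUPlace>();
}
```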
// Benchmark all jit kernels including jitcode, mkl and refer. // Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...] // To use this tool, run command: ./benchmark [options...]
......
...@@ -19,6 +19,8 @@ extern "C" { ...@@ -19,6 +19,8 @@ extern "C" {
} }
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <unordered_map>
#include <utility> // for std::move
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
...@@ -30,22 +32,22 @@ namespace paddle { ...@@ -30,22 +32,22 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value && std::is_same<typename KernelTuple::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value, std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type typename KernelTuple::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuples::func_type; using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuple::attr_type;
size_t key = JitCodeKey<Attr>(attr); size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance(); auto& codes = JitCodePool<KernelTuple::kernel_type>().Instance();
if (codes.Has(key)) { if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>(); return codes.AllKernels().at(key)->template getCode<Func>();
} }
// creator is not related with attr, so can use KernelKey as key // creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KernelTuple::kernel_type, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>) // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey); auto iter = creator_map.find(kkey);
...@@ -66,27 +68,27 @@ GetJitCode(const typename KernelTuples::attr_type& attr) { ...@@ -66,27 +68,27 @@ GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr; return nullptr;
} }
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value || !std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value, !std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type typename KernelTuple::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr; return nullptr;
} }
// Refer code do not related with attr, which is just for cast // Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace // Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples> template <typename KernelTuple>
inline typename KernelTuples::func_type GetRefer() { inline typename KernelTuple::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels(); auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace()); KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey); auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(), PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function."); "Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second; auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) { for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get()); auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
if (i) { if (i) {
return i->GetFunc(); return i->GetFunc();
} }
...@@ -94,23 +96,22 @@ inline typename KernelTuples::func_type GetRefer() { ...@@ -94,23 +96,22 @@ inline typename KernelTuples::func_type GetRefer() {
return nullptr; return nullptr;
} }
template <KernelType KT, typename KernelTuples, template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename PlaceType = platform::CPUPlace> typename KernelTuple::func_type Get(
typename KernelTuples::func_type Get( const typename KernelTuple::attr_type& attr) {
const typename KernelTuples::attr_type& attr) { auto jitfunc = GetJitCode<KernelTuple, PlaceType>(attr);
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) { if (jitfunc) {
return jitfunc; return jitfunc;
} }
// pool: (KernelKey(type, place), vector<KernelPtr>) // pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = KernelPool().Instance().AllKernels(); auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey); auto iter = pool.find(kkey);
if (iter != pool.end()) { if (iter != pool.end()) {
auto& impls = iter->second; auto& impls = iter->second;
for (auto& impl : impls) { for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get()); auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) { if (i && i->UseMe(attr)) {
return i->GetFunc(); return i->GetFunc();
} }
...@@ -118,48 +119,50 @@ typename KernelTuples::func_type Get( ...@@ -118,48 +119,50 @@ typename KernelTuples::func_type Get(
} }
// The last implementation should be reference function on CPUPlace. // The last implementation should be reference function on CPUPlace.
return GetRefer<KT, KernelTuples>(); return GetRefer<KernelTuple>();
} }
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
class KernelFuncs { class KernelFuncs {
public: public:
KernelFuncs() = default; KernelFuncs() = default;
static KernelFuncs& Cache() { static KernelFuncs& Cache() {
static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache; static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
return g_func_cache; return g_func_cache;
} }
// the exposed interface to use // the exposed interface to use
typename KernelTuples::func_type At( typename KernelTuple::func_type At(
const typename KernelTuples::attr_type& attr) { const typename KernelTuple::attr_type& attr) {
// XXH64: 13.8 GB/s // XXH64: 13.8 GB/s
int64_t key = XXH64(&attr, sizeof(typename KernelTuples::attr_type), 0); // TODO(TJ): change me, maybe not all attr change need one key, should be
// attrkey
int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0);
if (Has(key)) { if (Has(key)) {
return funcs_.at(key); return funcs_.at(key);
} }
// If do not have this attr in cache, // If do not have this attr in cache,
// then could run some runtime benchmark of this attr and save the best one. // then could run some runtime benchmark of this attr and save the best one.
// Here just get the offline benchmarked best one. // Here just get the offline benchmarked best one.
auto func = Get<KT, KernelTuples, PlaceType>(attr); auto func = Get<KernelTuple, PlaceType>(attr);
Insert(key, func); Insert(key, func);
return func; return func;
} }
typename KernelTuples::func_type operator[]( typename KernelTuple::func_type operator[](
const typename KernelTuples::attr_type& attr) { const typename KernelTuple::attr_type& attr) {
return At(attr); return At(attr);
} }
protected: protected:
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); } bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int64_t key, typename KernelTuples::func_type func) { void Insert(int64_t key, typename KernelTuple::func_type func) {
funcs_.emplace(key, func); funcs_.emplace(key, func);
} }
private: private:
std::unordered_map<int64_t, typename KernelTuples::func_type> funcs_; std::unordered_map<int64_t, typename KernelTuple::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncs); DISABLE_COPY_AND_ASSIGN(KernelFuncs);
}; };
......
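A short usage sketch of the cache defined above (an illustration, not code from this commit; `d` is a hypothetical vector length, and `VMulTuple<float>` uses `int` as its attr_type):

```cpp
int d = 8;  // hypothetical problem size
// At() resolves in order: generated JIT code for this attr, then any registered
// "more" implementation (e.g. MKL or intrinsic) whose UseMe(attr) is true,
// then the reference kernel.
auto vmul = jit::KernelFuncs<jit::VMulTuple<float>,
                             platform::CPUPlace>::Cache()
                .At(d);
// A repeated lookup with the same attr hits the thread-local cache
// (keyed by XXH64 of the attr) instead of re-dispatching; operator[] is a
// shorthand for At().
auto vmul_again =
    jit::KernelFuncs<jit::VMulTuple<float>, platform::CPUPlace>::Cache()[d];
```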
...@@ -62,26 +62,55 @@ typedef enum { ...@@ -62,26 +62,55 @@ typedef enum {
kSqrt, kSqrt,
} SeqPoolType; } SeqPoolType;
// x, y, z, n
template <typename T> template <typename T>
struct XYZNTuples { struct XYZNTuple {
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int); typedef void (*func_type)(const T*, const T*, T*, int);
}; };
// a, x, y, n
template <typename T> template <typename T>
struct AXYNTuples : public XYZNTuples<T> {}; struct AXYNTuple : public XYZNTuple<T> {};
// x, y, n
template <typename T> template <typename T>
struct XYNTuples { struct XYNTuple {
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, T*, int); typedef void (*func_type)(const T*, T*, int);
}; };
// x, return and int // x, returned value, n
template <typename T> template <typename T>
struct XRNTuples : public XYNTuples<T> {}; struct XRNTuple : public XYNTuple<T> {};
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
template <typename T> \
struct type##Tuple : public kernel_tuple<T> { \
static constexpr KernelType kernel_type = k##type; \
}
// Tuple should be corresponding to the KernelType
DECLARE_KERNELTUPLE(XYZNTuple, VMul);
DECLARE_KERNELTUPLE(XYZNTuple, VAdd);
DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu);
DECLARE_KERNELTUPLE(XYZNTuple, VSub);
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
DECLARE_KERNELTUPLE(XYNTuple, VExp);
DECLARE_KERNELTUPLE(XYNTuple, VSigmoid);
DECLARE_KERNELTUPLE(XYNTuple, VTanh);
DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
typedef struct { typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh void* gates; // gates: x_ch, x_ih, x_fh, x_oh
...@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t; ...@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t; typedef struct lstm_attr_s lstm_attr_t;
template <typename T> template <typename T>
struct LSTMTuples { struct LSTMTuple {
typedef T data_type; typedef T data_type;
typedef lstm_attr_t attr_type; typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*); typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
}; };
template <typename T> template <typename T>
struct GRUTuples { struct GRUTuple {
typedef T data_type; typedef T data_type;
typedef gru_attr_t attr_type; typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*); typedef void (*func_type)(gru_t*, const gru_attr_t*);
}; };
DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt);
DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1);
DECLARE_KERNELTUPLE(GRUTuple, GRUH1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2);
#undef DECLARE_KERNELTUPLE
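Written out for readability, the first invocation in the block above expands to:

```cpp
// DECLARE_KERNELTUPLE(XYZNTuple, VMul); produces:
template <typename T>
struct VMulTuple : public XYZNTuple<T> {
  static constexpr KernelType kernel_type = kVMul;
};
```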
template <typename T> template <typename T>
struct VBroadcastTuples { struct VBroadcastTuple {
static constexpr KernelType kernel_type = kVBroadcast;
typedef T data_type; typedef T data_type;
typedef int64_t attr_type; typedef int64_t attr_type;
typedef void (*func_type)(const T*, T*, int64_t, int64_t); typedef void (*func_type)(const T*, T*, int64_t, int64_t);
...@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s { ...@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s {
} seq_pool_attr_t; } seq_pool_attr_t;
template <typename T> template <typename T>
struct SeqPoolTuples { struct SeqPoolTuple {
static constexpr KernelType kernel_type = kSeqPool;
typedef T data_type; typedef T data_type;
typedef seq_pool_attr_t attr_type; typedef seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
...@@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s { ...@@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s {
} emb_seq_pool_attr_t; } emb_seq_pool_attr_t;
template <typename T> template <typename T>
struct EmbSeqPoolTuples { struct EmbSeqPoolTuple {
static constexpr KernelType kernel_type = kEmbSeqPool;
typedef T data_type; typedef T data_type;
typedef emb_seq_pool_attr_t attr_type; typedef emb_seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, const int64_t*, T*, typedef void (*func_type)(const T*, const int64_t*, T*,
...@@ -198,7 +239,8 @@ typedef struct sgd_attr_s { ...@@ -198,7 +239,8 @@ typedef struct sgd_attr_s {
} sgd_attr_t; } sgd_attr_t;
template <typename T> template <typename T>
struct SgdTuples { struct SgdTuple {
static constexpr KernelType kernel_type = kSgd;
typedef T data_type; typedef T data_type;
typedef sgd_attr_t attr_type; typedef sgd_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*, typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
...@@ -214,21 +256,24 @@ typedef struct matmul_attr_s { ...@@ -214,21 +256,24 @@ typedef struct matmul_attr_s {
} matmul_attr_t; } matmul_attr_t;
template <typename T> template <typename T>
struct MatMulTuples { struct MatMulTuple {
static constexpr KernelType kernel_type = kMatMul;
typedef T data_type; typedef T data_type;
typedef matmul_attr_t attr_type; typedef matmul_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*); typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*);
}; };
template <typename T> template <typename T>
struct CRFDecodingTuples { struct CRFDecodingTuple {
static constexpr KernelType kernel_type = kCRFDecoding;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int); typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
}; };
template <typename T> template <typename T>
struct LayerNormTuples { struct LayerNormTuple {
static constexpr KernelType kernel_type = kLayerNorm;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int, typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
...@@ -236,7 +281,8 @@ struct LayerNormTuples { ...@@ -236,7 +281,8 @@ struct LayerNormTuples {
}; };
template <typename T> template <typename T>
struct SoftmaxTuples { struct SoftmaxTuple {
static constexpr KernelType kernel_type = kSoftmax;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int); typedef void (*func_type)(const T*, T*, int, int);
...@@ -244,7 +290,8 @@ struct SoftmaxTuples { ...@@ -244,7 +290,8 @@ struct SoftmaxTuples {
// nChw16c = nChw16c .* NC // nChw16c = nChw16c .* NC
template <typename T> template <typename T>
struct NCHW16CMulNCTuples { struct NCHW16CMulNCTuple {
static constexpr KernelType kernel_type = kNCHW16CMulNC;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int); typedef void (*func_type)(const T*, const T*, T*, int, int);
...@@ -258,12 +305,12 @@ class Kernel { ...@@ -258,12 +305,12 @@ class Kernel {
DISABLE_COPY_AND_ASSIGN(Kernel); DISABLE_COPY_AND_ASSIGN(Kernel);
}; };
template <typename KernelTuples> template <typename KernelTuple>
class KernelMore : public Kernel { class KernelMore : public Kernel {
public: public:
using T = typename KernelTuples::data_type; using T = typename KernelTuple::data_type;
using Func = typename KernelTuples::func_type; using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuple::attr_type;
virtual Func GetFunc() const { return func; } virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0; virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0; virtual const char* ImplType() const = 0;
...@@ -272,11 +319,11 @@ class KernelMore : public Kernel { ...@@ -272,11 +319,11 @@ class KernelMore : public Kernel {
Func func{nullptr}; Func func{nullptr};
}; };
template <typename KernelTuples> template <typename KernelTuple>
class ReferKernel : public KernelMore<KernelTuples> { class ReferKernel : public KernelMore<KernelTuple> {
public: public:
// Refer code can always be used // Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override { bool UseMe(const typename KernelTuple::attr_type& attr) const override {
return true; return true;
} }
const char* ImplType() const override { return "Refer"; } const char* ImplType() const override { return "Refer"; }
......
...@@ -26,11 +26,10 @@ namespace intrinsic { ...@@ -26,11 +26,10 @@ namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w, void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num); float* alpha, int* track, int tag_num);
class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> { class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
public: public:
CRFDecodingKernel() { this->func = CRFDecoding; } CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe( bool UseMe(const typename CRFDecodingTuple<float>::attr_type&) const override;
const typename CRFDecodingTuples<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; } const char* ImplType() const override { return "Intrinsic"; }
}; };
......
...@@ -27,10 +27,10 @@ void LayerNorm(float* x, float* out, float* mean, float* var, ...@@ -27,10 +27,10 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height, const float* scale, const float* bias, int height,
const float epsilon, int right); const float epsilon, int right);
class LayerNormKernel : public KernelMore<LayerNormTuples<float>> { class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
public: public:
LayerNormKernel() { this->func = LayerNorm; } LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override; bool UseMe(const typename LayerNormTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; } const char* ImplType() const override { return "Intrinsic"; }
}; };
......
...@@ -23,6 +23,8 @@ namespace jit { ...@@ -23,6 +23,8 @@ namespace jit {
namespace more { namespace more {
namespace mix { namespace mix {
using CPUPlace = platform::CPUPlace;
void VSigmoid(const T* x, T* y, int n) { void VSigmoid(const T* x, T* y, int n) {
const float min = SIGMOID_THRESHOLD_MIN; const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX; const float max = SIGMOID_THRESHOLD_MAX;
...@@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) { ...@@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i]; y[i] = static_cast<T>(0) - y[i];
} }
auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n); auto compute = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
compute(y, y, n); compute(y, y, n);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]); y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
...@@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) { ...@@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) {
void VTanh(const T* x, T* y, int n) { void VTanh(const T* x, T* y, int n) {
const T a = 2, b = -1; const T a = 2, b = -1;
auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n); auto compute_scal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n); auto compute_addbias = KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n); auto compute_sigmoid = KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(n);
compute_scal(&a, x, y, n); compute_scal(&a, x, y, n);
compute_sigmoid(y, y, n); compute_sigmoid(y, y, n);
compute_scal(&a, y, y, n); compute_scal(&a, y, y, n);
...@@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) { ...@@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) {
} }
void Softmax(const T* x, T* y, int n, int bs) { void Softmax(const T* x, T* y, int n, int bs) {
auto compute_hmax = auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n); auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_hsum = auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vscal =
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vaddbias = auto compute_vaddbias =
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n); KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vexp = auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
for (int i = 0; i < bs; ++i) { for (int i = 0; i < bs; ++i) {
T scalar; T scalar;
...@@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) { ...@@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) {
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
if (type == kVSigmoid) { if (type == kVSigmoid) {
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVRelu) { } else if (type == kVRelu) {
return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VReluTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVTanh) { } else if (type == kVTanh) {
return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VTanhTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVIdentity) { } else if (type == kVIdentity) {
return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
} }
PADDLE_THROW("Not support type: %s", type); PADDLE_THROW("Not support type: %s", type);
return nullptr; return nullptr;
...@@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { ...@@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
const int d = attr->d; const int d = attr->d;
const int d2 = d * 2; const int d2 = d * 2;
const int d3 = d * 3; const int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d); auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2); auto vadd_d2 = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d2);
auto act_gate_d = getActFunc(attr->act_gate, d); auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_gate_d2 = getActFunc(attr->act_gate, d2); auto act_gate_d2 = getActFunc(attr->act_gate, d2);
auto act_gate_d3 = getActFunc(attr->act_gate, d3); auto act_gate_d3 = getActFunc(attr->act_gate, d3);
...@@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { ...@@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
int d = attr->d; int d = attr->d;
int d2 = d * 2; int d2 = d * 2;
int d3 = d * 3; int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d); auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto act_gate_d = getActFunc(attr->act_gate, d); auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_cand_d = getActFunc(attr->act_cand, d); auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d); auto act_cell_d = getActFunc(attr->act_cell, d);
...@@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) { ...@@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
int d2 = d * 2; int d2 = d * 2;
auto act_gate = getActFunc(attr->act_gate, d); auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d); auto act_cand = getActFunc(attr->act_cand, d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
act_gate(gates, gates, d); act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d); act_cand(gates + d2, gates + d2, d);
vmul_d(gates, gates + d2, ht, d); vmul_d(gates, gates + d2, ht, d);
...@@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { ...@@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
T* ht = reinterpret_cast<T*>(step->ht); T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1); const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc(attr->act_gate, attr->d); auto act_gate = getActFunc(attr->act_gate, attr->d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(attr->d);
act_gate(gates + attr->d, gates + attr->d, attr->d); act_gate(gates + attr->d, gates + attr->d, attr->d);
vmul_d(ht_1, gates + attr->d, ht, attr->d); vmul_d(ht_1, gates + attr->d, ht, attr->d);
} }
...@@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; } ...@@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
namespace mix = paddle::operators::jit::more::mix; namespace mix = paddle::operators::jit::more::mix;
#define REGISTER_MORE_KERNEL(key, func) \ #define REGISTER_MORE_KERNEL(func) \
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel) REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid); REGISTER_MORE_KERNEL(VSigmoid);
REGISTER_MORE_KERNEL(kVTanh, VTanh); REGISTER_MORE_KERNEL(VTanh);
REGISTER_MORE_KERNEL(kSoftmax, Softmax); REGISTER_MORE_KERNEL(Softmax);
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt); REGISTER_MORE_KERNEL(LSTMCtHt);
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1); REGISTER_MORE_KERNEL(LSTMC1H1);
REGISTER_MORE_KERNEL(kGRUH1, GRUH1); REGISTER_MORE_KERNEL(GRUH1);
REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1); REGISTER_MORE_KERNEL(GRUHtPart1);
REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2); REGISTER_MORE_KERNEL(GRUHtPart2);
#undef REGISTER_MORE_KERNEL #undef REGISTER_MORE_KERNEL
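For readers following the API change: every call site in this commit now names a single kernel tuple type and lets the cache pick an implementation for a given attribute. The fragment below is a minimal, illustrative sketch of that lookup, not part of the commit; it assumes it is built inside the Paddle tree (the paddle/fluid/operators/jit/kernels.h umbrella header and the VSigmoidTuple type shown above), and the (const T* x, T* y, int n) signature is the one noted for XYN kernels later in this diff.

#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"

// Sketch: fetch a cached VSigmoid kernel for length n and run it once.
void vsigmoid_sketch() {
  const int n = 8;
  std::vector<float> x(n, 0.5f), y(n);
  // One template parameter (the kernel tuple) instead of the old
  // <KernelType, Tuples<T>, PlaceType> triple; the functor is cached
  // per attribute value (here the vector length n).
  auto sigmoid = paddle::operators::jit::KernelFuncs<
      paddle::operators::jit::VSigmoidTuple<float>,
      paddle::platform::CPUPlace>::Cache()
                     .At(n);
  sigmoid(x.data(), y.data(), n);  // y[i] = 1 / (1 + exp(-x[i]))
}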
...@@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr); ...@@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart1(gru_t* step, const gru_attr_t* attr); void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart2(gru_t* step, const gru_attr_t* attr); void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
#define DECLARE_MORE_KERNEL(name, tuples) \ #define DECLARE_MORE_KERNEL(name) \
class name##Kernel : public KernelMore<tuples<T>> { \ class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name; } \ name##Kernel() { this->func = name; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \ bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \ const char* ImplType() const override { return "Mixed"; } \
} }
// XYN // XYN
DECLARE_MORE_KERNEL(VSigmoid, XYNTuples); DECLARE_MORE_KERNEL(VSigmoid);
DECLARE_MORE_KERNEL(VTanh, XYNTuples); DECLARE_MORE_KERNEL(VTanh);
// XRN // XRN
DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples); DECLARE_MORE_KERNEL(Softmax);
DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_MORE_KERNEL(LSTMCtHt);
DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples); DECLARE_MORE_KERNEL(LSTMC1H1);
DECLARE_MORE_KERNEL(GRUH1, GRUTuples); DECLARE_MORE_KERNEL(GRUH1);
DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples); DECLARE_MORE_KERNEL(GRUHtPart1);
DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples); DECLARE_MORE_KERNEL(GRUHtPart2);
#undef DECLARE_MORE_KERNEL #undef DECLARE_MORE_KERNEL
......
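To make the simplified declaration macro concrete, this is roughly what DECLARE_MORE_KERNEL(VTanh) expands to after the change, written out by hand from the macro in the hunk above. The element type T is assumed to be fixed by a typedef elsewhere in mix.h, which is not shown in this diff.

// Hand expansion of DECLARE_MORE_KERNEL(VTanh) (sketch) -- the tuple type is
// now derived from the kernel name instead of being passed separately.
class VTanhKernel : public KernelMore<VTanhTuple<T>> {
 public:
  VTanhKernel() { this->func = VTanh; }
  bool UseMe(const typename VTanhTuple<T>::attr_type&) const override;
  const char* ImplType() const override { return "Mixed"; }
};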
...@@ -250,23 +250,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax); ...@@ -250,23 +250,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax);
namespace mkl = paddle::operators::jit::more::mkl; namespace mkl = paddle::operators::jit::more::mkl;
#define REGISTER_MKL_KERNEL(key, func) \ #define REGISTER_MKL_KERNEL(func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \ REGISTER_JITKERNEL_MORE(k##func, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>) mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul); REGISTER_MKL_KERNEL(MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul); REGISTER_MKL_KERNEL(VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd); REGISTER_MKL_KERNEL(VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(VScal);
REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare); REGISTER_MKL_KERNEL(VSquare);
REGISTER_MKL_KERNEL(kVCopy, VCopy); REGISTER_MKL_KERNEL(VCopy);
REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast); REGISTER_MKL_KERNEL(VBroadcast);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool); REGISTER_MKL_KERNEL(SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool); REGISTER_MKL_KERNEL(EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax); REGISTER_MKL_KERNEL(Softmax);
REGISTER_MKL_KERNEL(kSgd, Sgd); REGISTER_MKL_KERNEL(Sgd);
#undef REGISTER_MKL_KERNEL #undef REGISTER_MKL_KERNEL
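The registration macros follow the same convention: the dispatch key is rebuilt from the function name by prepending k, so the explicit key argument could be dropped. As an illustration (hand-expanded from the macro above, not new code in the commit), REGISTER_MKL_KERNEL(VScal) now produces:

// Expansion of REGISTER_MKL_KERNEL(VScal) after this commit (sketch):
REGISTER_JITKERNEL_MORE(kVScal, mkl, mkl::VScalKernel<float>,
                        mkl::VScalKernel<double>);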
...@@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, ...@@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
} }
} }
#define DECLARE_MKL_KERNEL(name, tuples) \ #define DECLARE_MKL_KERNEL(name) \
template <typename T> \ template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \ class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name<T>; } \ name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \ bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \ const char* ImplType() const override { return "MKL"; } \
} }
// ABCMNK // ABCMNK
DECLARE_MKL_KERNEL(MatMul, MatMulTuples); DECLARE_MKL_KERNEL(MatMul);
// XYZN // XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples); DECLARE_MKL_KERNEL(VMul);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples); DECLARE_MKL_KERNEL(VAdd);
// AXYN // AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples); DECLARE_MKL_KERNEL(VScal);
// XYN // XYN
DECLARE_MKL_KERNEL(VExp, XYNTuples); DECLARE_MKL_KERNEL(VExp);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid);
DECLARE_MKL_KERNEL(VTanh, XYNTuples); DECLARE_MKL_KERNEL(VTanh);
DECLARE_MKL_KERNEL(VSquare, XYNTuples); DECLARE_MKL_KERNEL(VSquare);
DECLARE_MKL_KERNEL(VCopy, XYNTuples); DECLARE_MKL_KERNEL(VCopy);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); // others
DECLARE_MKL_KERNEL(SeqPool);
DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples); DECLARE_MKL_KERNEL(EmbSeqPool);
DECLARE_MKL_KERNEL(Softmax);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples); DECLARE_MKL_KERNEL(Sgd);
DECLARE_MKL_KERNEL(VBroadcast);
DECLARE_MKL_KERNEL(Sgd, SgdTuples);
DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);
#undef DECLARE_MKL_KERNEL #undef DECLARE_MKL_KERNEL
......
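For the MKL backend the declaration macro keeps its template parameter, so DECLARE_MKL_KERNEL(VExp) hand-expands roughly as below; as with the mix variant, this is only an illustration of the macro shown above, not code added by the commit.

// Hand expansion of DECLARE_MKL_KERNEL(VExp) (sketch):
template <typename T>
class VExpKernel : public KernelMore<VExpTuple<T>> {
 public:
  VExpKernel() { this->func = VExp<T>; }
  bool UseMe(const typename VExpTuple<T>::attr_type&) const override;
  const char* ImplType() const override { return "MKL"; }
};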
...@@ -17,51 +17,43 @@ ...@@ -17,51 +17,43 @@
namespace refer = paddle::operators::jit::refer; namespace refer = paddle::operators::jit::refer;
#define REGISTER_REFER_KERNEL(key, func) \ #define REGISTER_REFER_KERNEL(func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \ REGISTER_JITKERNEL_REFER(k##func, refer::func##Kernel<float>, \
refer::func##Kernel<double>) refer::func##Kernel<double>)
REGISTER_REFER_KERNEL(kVMul, VMul); REGISTER_REFER_KERNEL(VMul);
REGISTER_REFER_KERNEL(kVAdd, VAdd); REGISTER_REFER_KERNEL(VAdd);
REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu); REGISTER_REFER_KERNEL(VAddRelu);
REGISTER_REFER_KERNEL(kVSub, VSub); REGISTER_REFER_KERNEL(VSub);
REGISTER_REFER_KERNEL(kVScal, VScal); REGISTER_REFER_KERNEL(VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias); REGISTER_REFER_KERNEL(VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu); REGISTER_REFER_KERNEL(VRelu);
REGISTER_REFER_KERNEL(kVCopy, VCopy); REGISTER_REFER_KERNEL(VCopy);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity); REGISTER_REFER_KERNEL(VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare); REGISTER_REFER_KERNEL(VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp); REGISTER_REFER_KERNEL(VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid); REGISTER_REFER_KERNEL(VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh); REGISTER_REFER_KERNEL(VTanh);
REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt); REGISTER_REFER_KERNEL(LSTMCtHt);
REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1); REGISTER_REFER_KERNEL(LSTMC1H1);
REGISTER_REFER_KERNEL(kGRUH1, GRUH1); REGISTER_REFER_KERNEL(GRUH1);
REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1); REGISTER_REFER_KERNEL(GRUHtPart1);
REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2); REGISTER_REFER_KERNEL(GRUHtPart2);
REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding); REGISTER_REFER_KERNEL(CRFDecoding);
REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); REGISTER_REFER_KERNEL(LayerNorm);
REGISTER_REFER_KERNEL(NCHW16CMulNC);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); REGISTER_REFER_KERNEL(SeqPool);
REGISTER_REFER_KERNEL(MatMul);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool); REGISTER_REFER_KERNEL(HMax);
REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(kMatMul, MatMul); REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(kHMax, HMax); REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(kHSum, HSum); REGISTER_REFER_KERNEL(VBroadcast);
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_REFER_KERNEL(kSgd, Sgd);
REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
#undef REGISTER_REFER_KERNEL #undef REGISTER_REFER_KERNEL
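The reference backend registers both precisions the same way; hand-expanding one line of the list above, REGISTER_REFER_KERNEL(VMul) becomes:

// Expansion of REGISTER_REFER_KERNEL(VMul) after this commit (sketch):
REGISTER_JITKERNEL_REFER(kVMul, refer::VMulKernel<float>,
                         refer::VMulKernel<double>);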
...@@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, ...@@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
} }
} }
#define DECLARE_REFER_KERNEL(name, tuples) \ #define DECLARE_REFER_KERNEL(name) \
template <typename T> \ template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \ class name##Kernel : public ReferKernel<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name<T>; } \ name##Kernel() { this->func = name<T>; } \
} }
// const T* x, const T* y, T* z, int n // const T* x, const T* y, T* z, int n
DECLARE_REFER_KERNEL(VMul, XYZNTuples); DECLARE_REFER_KERNEL(VMul);
DECLARE_REFER_KERNEL(VAdd, XYZNTuples); DECLARE_REFER_KERNEL(VAdd);
DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples); DECLARE_REFER_KERNEL(VAddRelu);
DECLARE_REFER_KERNEL(VSub, XYZNTuples); DECLARE_REFER_KERNEL(VSub);
// const T* a, const T* x, T* y, int n // const T* a, const T* x, T* y, int n
DECLARE_REFER_KERNEL(VScal, AXYNTuples); DECLARE_REFER_KERNEL(VScal);
DECLARE_REFER_KERNEL(VAddBias, AXYNTuples); DECLARE_REFER_KERNEL(VAddBias);
// const T* x, T* y, int n // const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu, XYNTuples); DECLARE_REFER_KERNEL(VRelu);
DECLARE_REFER_KERNEL(VIdentity, XYNTuples); DECLARE_REFER_KERNEL(VIdentity);
DECLARE_REFER_KERNEL(VExp, XYNTuples); DECLARE_REFER_KERNEL(VExp);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); DECLARE_REFER_KERNEL(VSigmoid);
DECLARE_REFER_KERNEL(VTanh, XYNTuples); DECLARE_REFER_KERNEL(VTanh);
DECLARE_REFER_KERNEL(VSquare, XYNTuples); DECLARE_REFER_KERNEL(VSquare);
DECLARE_REFER_KERNEL(VCopy, XYNTuples); DECLARE_REFER_KERNEL(VCopy);
// lstm_t*, const lstm_attr_t* // lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_REFER_KERNEL(LSTMCtHt);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples); DECLARE_REFER_KERNEL(LSTMC1H1);
// gru_t*, const gru_attr_t* // gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL(GRUH1, GRUTuples); DECLARE_REFER_KERNEL(GRUH1);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples); DECLARE_REFER_KERNEL(GRUHtPart1);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples); DECLARE_REFER_KERNEL(GRUHtPart2);
DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples); DECLARE_REFER_KERNEL(HMax);
DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); DECLARE_REFER_KERNEL(HSum);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); // others
DECLARE_REFER_KERNEL(CRFDecoding);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); DECLARE_REFER_KERNEL(LayerNorm);
DECLARE_REFER_KERNEL(NCHW16CMulNC);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples); DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(HMax, XRNTuples); DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(HSum, XRNTuples); DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples); DECLARE_REFER_KERNEL(VBroadcast);
DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_REFER_KERNEL(Sgd, SgdTuples);
DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
#undef DECLARE_REFER_KERNEL #undef DECLARE_REFER_KERNEL
......
This diff is collapsed.
...@@ -229,9 +229,9 @@ class LayerNormKernel : public framework::OpKernel<T> { ...@@ -229,9 +229,9 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(scale->numel(), right); PADDLE_ENFORCE_EQ(scale->numel(), right);
PADDLE_ENFORCE_EQ(bias->numel(), right); PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker = jit::KernelFuncs<jit::kLayerNorm, jit::LayerNormTuples<T>, auto ker =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
.At(right); .At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(), ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left), scale->data<T>(), bias->data<T>(), static_cast<int>(left),
static_cast<const float>(epsilon), right); static_cast<const float>(epsilon), right);
......
...@@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M, ...@@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
return; return;
} }
if (relu) { if (relu) {
auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>, auto compute =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
.At(N); N);
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
compute(B, dst, dst, N); compute(B, dst, dst, N);
} }
} else { } else {
auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>, auto compute =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
.At(N);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
#endif #endif
......
...@@ -255,9 +255,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -255,9 +255,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
jit::seq_pool_attr_t attr( jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]), static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum); jit::SeqPoolType::kSum);
auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>, auto seqpool =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
.At(attr); .At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]); attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr); seqpool(src, dst, &attr);
......
...@@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> { ...@@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
const int kClassDim = 1; const int kClassDim = 1;
// 2D data. Batch x C // 2D data. Batch x C
auto compute_softmax = auto compute_softmax =
jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>, jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]); .At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]); compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
} }
......
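Operator-side call sites all reduce to the same two steps: fetch the functor cached for the inner dimension, then call it with raw pointers. Below is a small, illustrative sketch of the 2-D softmax case shown above; the SoftmaxTuple name and the (x, y, n, bs) signature come from earlier hunks in this commit, and the include path assumes the Paddle tree.

#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"

// Sketch: softmax over bs rows of n floats, as used by SoftmaxFunctor above.
void softmax_rows_sketch() {
  const int n = 4, bs = 2;
  std::vector<float> in(n * bs, 1.0f), out(n * bs);
  auto softmax = paddle::operators::jit::KernelFuncs<
      paddle::operators::jit::SoftmaxTuple<float>,
      paddle::platform::CPUPlace>::Cache()
                     .At(n);
  softmax(in.data(), out.data(), n, bs);  // each row of out sums to 1
}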
...@@ -47,9 +47,9 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -47,9 +47,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
int64_t rows_idx = 0; int64_t rows_idx = 0;
T *out_data = param_out->mutable_data<T>(ctx.GetPlace()); T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>, auto sgd =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced. // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
...@@ -82,9 +82,9 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -82,9 +82,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
attr.selected_rows_size = grad_rows.size(); attr.selected_rows_size = grad_rows.size();
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width); PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>, auto sgd =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr); sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");
......