From 14a764c930c4ff895168d482db51c21b6338f283 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 8 Mar 2019 08:04:04 +0000 Subject: [PATCH] simplify the jitkernel templates and tests test=develop --- paddle/fluid/operators/crf_decoding_op.h | 6 +- .../mkldnn/elementwise_mul_mkldnn_op.cc | 7 +- .../fused/fused_embedding_seq_pool_op.h | 6 +- paddle/fluid/operators/fused/fusion_gru_op.cc | 52 +- .../fluid/operators/fused/fusion_lstm_op.cc | 56 +- .../fused/fusion_repeated_fc_relu_op.cc | 12 +- .../fused/fusion_seqpool_concat_op.cc | 6 +- .../fused/fusion_squared_mat_sub_op.cc | 36 +- paddle/fluid/operators/jit/benchmark.cc | 269 ++-- paddle/fluid/operators/jit/helper.h | 71 +- paddle/fluid/operators/jit/kernel_base.h | 93 +- .../jit/more/intrinsic/crf_decoding.h | 5 +- .../operators/jit/more/intrinsic/layer_norm.h | 4 +- paddle/fluid/operators/jit/more/mix/mix.cc | 68 +- paddle/fluid/operators/jit/more/mix/mix.h | 28 +- paddle/fluid/operators/jit/more/mkl/mkl.cc | 32 +- paddle/fluid/operators/jit/more/mkl/mkl.h | 49 +- paddle/fluid/operators/jit/refer/refer.cc | 80 +- paddle/fluid/operators/jit/refer/refer.h | 80 +- paddle/fluid/operators/jit/test.cc | 1282 ++++++++--------- paddle/fluid/operators/layer_norm_op.h | 6 +- paddle/fluid/operators/math/fc_compute.h | 11 +- .../fluid/operators/math/sequence_pooling.cc | 6 +- paddle/fluid/operators/math/softmax_impl.h | 3 +- paddle/fluid/operators/optimizers/sgd_op.h | 12 +- 25 files changed, 1135 insertions(+), 1145 deletions(-) diff --git a/paddle/fluid/operators/crf_decoding_op.h b/paddle/fluid/operators/crf_decoding_op.h index 3d98790a4d..d6b54038ec 100644 --- a/paddle/fluid/operators/crf_decoding_op.h +++ b/paddle/fluid/operators/crf_decoding_op.h @@ -82,9 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel { Tensor track; int* track_value = track.mutable_data(emission_dims, platform::CPUPlace()); - auto ker = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(tag_num); + auto ker = + jit::KernelFuncs, platform::CPUPlace>::Cache() + .At(tag_num); ker(static_cast(seq_len), x, w, alpha_value, track_value, tag_num); T max_score = -std::numeric_limits::max(); int max_i = 0; diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index e37bbd2837..f2f4d3fee0 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -110,10 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { constexpr int simd_width = 16; int C = c / simd_width; - auto multiply = - jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(0); + auto multiply = jit::KernelFuncs, + platform::CPUPlace>::Cache() + .At(0); #pragma omp parallel for collapse(2) for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h index fe43545e60..5e2e336e71 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -53,8 +53,7 @@ struct EmbeddingVSumFunctor { for (size_t i = 0; i != ids_lod.size() - 1; ++i) { attr.index_height = ids_lod[i + 1] - ids_lod[i]; auto emb_seqpool = - jit::KernelFuncs, - platform::CPUPlace>::Cache() + jit::KernelFuncs, platform::CPUPlace>::Cache() .At(attr); emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i 
* out_width, &attr); @@ -138,8 +137,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel { const T *d_output_data = d_output->data(); auto vbroadcast = - jit::KernelFuncs, - platform::CPUPlace>::Cache() + jit::KernelFuncs, platform::CPUPlace>::Cache() .At(out_width); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { int64_t h = static_cast(lod[i + 1] - lod[i]); diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index cd8a6a55d4..ba5f0747c4 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -182,32 +182,32 @@ class FusionGRUKernel : public framework::OpKernel { const int total_T = x_dims[0]; \ const int D3 = wh_dims[1] -#define INIT_OTHER_DEFINES \ - auto* h0 = ctx.Input("H0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* bias = ctx.Input("Bias"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - bool is_reverse = ctx.Attr("is_reverse"); \ - const int M = x_dims[1]; \ - const int D = wh_dims[0]; \ - const int D2 = D * 2; \ - const jit::gru_attr_t attr( \ - D, jit::to_kerneltype(ctx.Attr("gate_activation")), \ - jit::to_kerneltype(ctx.Attr("activation"))); \ - jit::gru_t one_step; \ - auto ComputeH1 = jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - auto ComputeHtPart1 = jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - auto ComputeHtPart2 = jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - auto place = ctx.GetPlace(); \ +#define INIT_OTHER_DEFINES \ + auto* h0 = ctx.Input("H0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* bias = ctx.Input("Bias"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D2 = D * 2; \ + const jit::gru_attr_t attr( \ + D, jit::to_kerneltype(ctx.Attr("gate_activation")), \ + jit::to_kerneltype(ctx.Attr("activation"))); \ + jit::gru_t one_step; \ + auto ComputeH1 = \ + jit::KernelFuncs, platform::CPUPlace>::Cache().At( \ + attr); \ + auto ComputeHtPart1 = \ + jit::KernelFuncs, platform::CPUPlace>::Cache() \ + .At(attr); \ + auto ComputeHtPart2 = \ + jit::KernelFuncs, platform::CPUPlace>::Cache() \ + .At(attr); \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + auto place = ctx.GetPlace(); \ T* xx_data = xx->mutable_data(place) void SeqCompute(const framework::ExecutionContext& ctx) const { diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index d7d12df4bf..c8c07bd126 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -235,34 +235,34 @@ class FuisonLSTMKernel : public framework::OpKernel { const int D = wh_dims[0]; \ const int D4 = wh_dims[1] -#define INIT_OTHER_DEFINES \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - /* diagonal weight*/ \ - const T* wp_data = bias->data() + D4; \ - /* for peephole only*/ \ - T* checked_cell_data = nullptr; \ - auto place = ctx.GetPlace(); \ - if (use_peepholes) { \ - /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ - auto* checked_cell = ctx.Output("CheckedCell"); \ - checked_cell_data = checked_cell->mutable_data(place); \ - } \ - const jit::lstm_attr_t attr( \ - D, 
jit::to_kerneltype(ctx.Attr("gate_activation")), \ - jit::to_kerneltype(ctx.Attr("candidate_activation")), \ - jit::to_kerneltype(ctx.Attr("cell_activation")), \ - use_peepholes); \ - jit::lstm_t one_step; \ - one_step.wp = wp_data; \ - one_step.checked = checked_cell_data; \ - auto ComputeC1H1 = jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - auto ComputeCtHt = jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr) +#define INIT_OTHER_DEFINES \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T* wp_data = bias->data() + D4; \ + /* for peephole only*/ \ + T* checked_cell_data = nullptr; \ + auto place = ctx.GetPlace(); \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + auto* checked_cell = ctx.Output("CheckedCell"); \ + checked_cell_data = checked_cell->mutable_data(place); \ + } \ + const jit::lstm_attr_t attr( \ + D, jit::to_kerneltype(ctx.Attr("gate_activation")), \ + jit::to_kerneltype(ctx.Attr("candidate_activation")), \ + jit::to_kerneltype(ctx.Attr("cell_activation")), \ + use_peepholes); \ + jit::lstm_t one_step; \ + one_step.wp = wp_data; \ + one_step.checked = checked_cell_data; \ + auto ComputeC1H1 = \ + jit::KernelFuncs, platform::CPUPlace>::Cache().At( \ + attr); \ + auto ComputeCtHt = \ + jit::KernelFuncs, platform::CPUPlace>::Cache().At( \ + attr) // Wh GEMM #define GEMM_WH_ADDON(bs, prev, out) \ diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc index e057724b5a..6be35de65f 100644 --- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc @@ -81,12 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() { template static void fc_relu(const T* x, const T* w, const T* b, T* y, const jit::matmul_attr_t& attr) { - auto matmul = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr); - auto addbias_relu = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr.n); + auto matmul = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr); + auto addbias_relu = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr.n); matmul(x, w, y, &attr); T* dst = y; for (int i = 0; i < attr.m; ++i) { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc index 7aeeabc512..25916768c0 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc @@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel { } else if (pooltype == "SQRT") { attr.type = jit::SeqPoolType::kSqrt; } - auto seqpool = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr); + auto seqpool = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr); size_t n = ins.size(); size_t dst_step_size = n * w; for (size_t i = 0; i < n; ++i) { diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc index 9382bf0ebb..53679ebdde 100644 --- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc @@ -93,24 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel { attr.n = y_dims[1]; int o_numel = attr.m * attr.n; - auto vsquare_x = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr.m * 
attr.k); - auto vsquare_y = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr.k * attr.n); - auto vsquare_xy = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(o_numel); - auto vsub = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(o_numel); - auto vscal = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(o_numel); - auto matmul = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr); + auto vsquare_x = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr.m * attr.k); + auto vsquare_y = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr.k * attr.n); + auto vsquare_xy = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + o_numel); + auto vsub = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + o_numel); + auto vscal = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + o_numel); + auto matmul = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr); const T* x_data = x->data(); const T* y_data = y->data(); diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index deb96ee6cd..773cf38eb9 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) { InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \ void BenchJITKernel_##name##_##dtype##_##place##_::Run() -#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU) - void RUN_ALL_BENCHMARK() { for (auto p : g_all_benchmarks) { if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) { @@ -90,11 +88,11 @@ std::vector TestSizes() { return s; } -template +template struct BenchFunc { // return this function avg time // TODO(TJ): clear cache every time - double operator()(const typename KernelTuples::func_type tgt, Args... args) { + double operator()(const typename KernelTuple::func_type tgt, Args... args) { for (int i = 0; i < FLAGS_burning; ++i) { tgt(args...); } @@ -109,31 +107,30 @@ struct BenchFunc { namespace jit = paddle::operators::jit; -template -void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { - BenchFunc benchmark; +template +void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) { + BenchFunc benchmark; std::vector> infos; // test refer - auto refer = jit::GetRefer(); + auto refer = jit::GetRefer(); if (!refer) { LOG(FATAL) << "Refer can not be empty!"; } infos.push_back(std::make_pair("Refer", benchmark(refer, args...))); // test jitcode - auto jitcode = jit::GetJitCode(attr); + auto jitcode = jit::GetJitCode(attr); if (jitcode) { infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...))); } // test all impls in more - jit::KernelKey kkey(KT, PlaceType()); + jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType()); auto& pool = jit::KernelPool().Instance().AllKernels(); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { auto more = i->GetFunc(); infos.push_back( @@ -142,7 +139,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { } } // Test result from Get function - auto tgt = jit::KernelFuncs::Cache().At(attr); + auto tgt = jit::KernelFuncs::Cache().At(attr); if (!tgt) { LOG(FATAL) << "Target can not be empty!"; } @@ -150,7 +147,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... 
args) { // print std::ostringstream loginfos; - loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": "; + loginfos << "Kernel Type " << jit::to_string(KernelTuple::kernel_type) << ": " + << attr << ": "; for (auto pair : infos) { loginfos << pair.first << " takes " << pair.second << " us; "; } @@ -159,8 +157,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { using Tensor = paddle::framework::Tensor; -template -void BenchXYZNKernel() { +template +void BenchKernelXYZN() { + using T = typename KernelTuple::data_type; for (int d : TestSizes()) { Tensor x, y, z; x.Resize({d}); @@ -171,16 +170,16 @@ void BenchXYZNKernel() { T* z_data = z.mutable_data(PlaceType()); RandomVec(d, x_data); RandomVec(d, y_data); - BenchAllImpls, PlaceType>(d, x.data(), - y.data(), z_data, d); + BenchAllImpls(d, x.data(), y.data(), z_data, + d); // test inplace - BenchAllImpls, PlaceType>(d, x.data(), z_data, - z_data, d); + BenchAllImpls(d, x.data(), z_data, z_data, d); } } -template -void BenchAXYNKernel() { +template +void BenchKernelAXYN() { + using T = typename KernelTuple::data_type; for (int d : TestSizes()) { const T a = static_cast(3); Tensor x, y; @@ -189,26 +188,26 @@ void BenchAXYNKernel() { T* x_data = x.mutable_data(PlaceType()); T* y_data = y.mutable_data(PlaceType()); RandomVec(d, x_data); - BenchAllImpls, PlaceType>(d, &a, x.data(), y_data, - d); + BenchAllImpls(d, &a, x.data(), y_data, d); // test inplace - BenchAllImpls, PlaceType>(d, &a, x.data(), x_data, - d); + BenchAllImpls(d, &a, x.data(), x_data, d); } } -template -void BenchXRNKernel() { +template +void BenchKernelXRN() { + using T = typename KernelTuple::data_type; for (int d : TestSizes()) { Tensor x; RandomVec(d, x.mutable_data({d}, PlaceType())); T res; - BenchAllImpls, PlaceType>(d, x.data(), &res, d); + BenchAllImpls(d, x.data(), &res, d); } } -template -void BenchXYNKernel() { +template +void BenchKernelXYN() { + using T = typename KernelTuple::data_type; for (int d : TestSizes()) { Tensor x, y; x.Resize({d}); @@ -216,12 +215,13 @@ void BenchXYNKernel() { T* x_data = x.mutable_data(PlaceType()); T* y_data = y.mutable_data(PlaceType()); RandomVec(d, x_data); - BenchAllImpls, PlaceType>(d, x.data(), y_data, d); + BenchAllImpls(d, x.data(), y_data, d); } } -template -void BenchLSTMKernel() { +template +void BenchKernelLSTM() { + using T = typename KernelTuple::data_type; for (bool use_peephole : {true, false}) { for (int d : TestSizes()) { const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh, @@ -252,13 +252,14 @@ void BenchLSTMKernel() { step.wp = wp_data; step.checked = checked_data; } - BenchAllImpls, PlaceType>(attr, &step, &attr); + BenchAllImpls(attr, &step, &attr); } } } -template -void BenchGRUKernel() { +template +void BenchKernelGRU() { + using T = typename KernelTuple::data_type; for (int d : TestSizes()) { const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh); auto place = PlaceType(); @@ -275,12 +276,13 @@ void BenchGRUKernel() { step.gates = x_data; step.ht_1 = ht_1_data; step.ht = ht_data; - BenchAllImpls, PlaceType>(attr, &step, &attr); + BenchAllImpls(attr, &step, &attr); } } -template -void BenchSeqPoolKernel() { +template +void BenchKernelSeqPool() { + using T = typename KernelTuple::data_type; std::vector pool_types = { jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; for (auto type : pool_types) { @@ -294,15 +296,15 @@ void BenchSeqPoolKernel() { RandomVec(h * w, x.mutable_data(PlaceType()), -2.f, 2.f); const T* 
x_data = x.data(); T* y_data = y.mutable_data(PlaceType()); - BenchAllImpls, PlaceType>(attr, x_data, - y_data, &attr); + BenchAllImpls(attr, x_data, y_data, &attr); } } } } -template -void BenchEmbSeqPoolKernel() { +template +void BenchKernelEmbSeqPool() { + using T = typename KernelTuple::data_type; std::vector pool_types = {jit::SeqPoolType::kSum}; int64_t tbl_h = 1e4; for (int tbl_w : {10, 16, 256}) { @@ -324,16 +326,17 @@ void BenchEmbSeqPoolKernel() { tbl_h - 1); const int64_t* idx_data = idx.data(); T* o_data = out.mutable_data(PlaceType()); - BenchAllImpls, PlaceType>( - attr, table_data, idx_data, o_data, &attr); + BenchAllImpls(attr, table_data, idx_data, + o_data, &attr); } } } } } -template -void BenchSgdKernel() { +template +void BenchKernelSgd() { + using T = typename KernelTuple::data_type; const T lr = 0.1; auto UnDuplicatedRandomVec = [](int n, const int64_t lower, const int64_t upper) -> std::vector { @@ -364,15 +367,16 @@ void BenchSgdKernel() { const T* grad_data = grad.data(); const int64_t* rows_data = rows.data(); jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size); - BenchAllImpls, PlaceType>( - attr, &lr, param_data, grad_data, rows_data, param_data, &attr); + BenchAllImpls(attr, &lr, param_data, grad_data, + rows_data, param_data, &attr); } } } } -template -void BenchMatMulKernel() { +template +void BenchKernelMatMul() { + using T = typename KernelTuple::data_type; for (int m : {1, 2, 3, 4}) { for (int n : TestSizes()) { for (int k : TestSizes()) { @@ -386,15 +390,16 @@ void BenchMatMulKernel() { const T* b_data = b.data(); T* c_data = c.mutable_data(PlaceType()); const jit::matmul_attr_t attr{m, n, k}; - BenchAllImpls, PlaceType>(attr, a_data, b_data, - c_data, &attr); + BenchAllImpls(attr, a_data, b_data, c_data, + &attr); } } } } -template -void BenchSoftmaxKernel() { +template +void BenchKernelSoftmax() { + using T = typename KernelTuple::data_type; for (int bs : {1, 2, 10}) { for (int n : TestSizes()) { Tensor x, y; @@ -403,14 +408,14 @@ void BenchSoftmaxKernel() { RandomVec(bs * n, x.mutable_data(PlaceType()), -2.f, 2.f); const T* x_data = x.data(); T* y_data = y.mutable_data(PlaceType()); - BenchAllImpls, PlaceType>(n, x_data, y_data, n, - bs); + BenchAllImpls(n, x_data, y_data, n, bs); } } } -template -void BenchLayerNormKernel() { +template +void BenchKernelLayerNorm() { + using T = typename KernelTuple::data_type; const T epsilon = 9.99999975e-06; for (int n : {1, 2, 10}) { for (int x_dim_0 : {1, 9, 17, 50}) { @@ -439,16 +444,17 @@ void BenchLayerNormKernel() { T* var_data = var.data(); T* out_data = out.mutable_data(PlaceType()); - BenchAllImpls, PlaceType>( - right, x_data, out_data, mean_data, var_data, scale_data, bias_data, - left, epsilon, right); + BenchAllImpls(right, x_data, out_data, + mean_data, var_data, scale_data, + bias_data, left, epsilon, right); } } } } -template -void BenchCRFDecodingKernel() { +template +void BenchKernelCRFDecoding() { + using T = typename KernelTuple::data_type; constexpr int state_trans_base_idx = 2; for (int seq_len : {1, 11, 17, 50}) { for (int tag_num : TestSizes()) { @@ -468,14 +474,15 @@ void BenchCRFDecodingKernel() { T* alpha_data = alpha.mutable_data(PlaceType()); int* track_data = track.mutable_data(PlaceType()); - BenchAllImpls, PlaceType>( - tag_num, seq_len, x_data, w_data, alpha_data, track_data, tag_num); + BenchAllImpls(tag_num, seq_len, x_data, w_data, + alpha_data, track_data, tag_num); } } } -template -void BenchVBroadcastKernel() { +template +void BenchKernelVBroadcast() { + using 
T = typename KernelTuple::data_type; for (int64_t w : {1, 16, 64, 100, 256}) { Tensor x; x.Resize({w}); @@ -485,78 +492,86 @@ void BenchVBroadcastKernel() { Tensor y; y.Resize({h * w}); T* y_data = y.mutable_data(PlaceType()); - BenchAllImpls, PlaceType>( - w, x_data, y_data, static_cast(h), w); + BenchAllImpls(w, x_data, y_data, + static_cast(h), w); } } } -using T = float; -using CPUPlace = paddle::platform::CPUPlace; +#define BenchKernelVMul BenchKernelXYZN +#define BenchKernelVAdd BenchKernelXYZN +#define BenchKernelVAddRelu BenchKernelXYZN +#define BenchKernelVSub BenchKernelXYZN -// xyzn -BENCH_FP32_CPU(kVMul) { BenchXYZNKernel(); } -BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel(); } -BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel(); } -BENCH_FP32_CPU(kVSub) { BenchXYZNKernel(); } +#define BenchKernelVScal BenchKernelAXYN +#define BenchKernelVAddBias BenchKernelAXYN -// axyn -BENCH_FP32_CPU(kVScal) { BenchAXYNKernel(); } -BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel(); } +#define BenchKernelVRelu BenchKernelXYN +#define BenchKernelVIdentity BenchKernelXYN +#define BenchKernelVSquare BenchKernelXYN +#define BenchKernelVExp BenchKernelXYN +#define BenchKernelVSigmoid BenchKernelXYN +#define BenchKernelVTanh BenchKernelXYN +#define BenchKernelVCopy BenchKernelXYN -// xrn -BENCH_FP32_CPU(kHSum) { BenchXRNKernel(); } -BENCH_FP32_CPU(kHMax) { BenchXRNKernel(); } +#define BenchKernelHMax BenchKernelXRN +#define BenchKernelHSum BenchKernelXRN -// xyn -BENCH_FP32_CPU(kVRelu) { BenchXYNKernel(); } -BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel(); } -BENCH_FP32_CPU(kVSquare) { BenchXYNKernel(); } -BENCH_FP32_CPU(kVExp) { BenchXYNKernel(); } -BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel(); } -BENCH_FP32_CPU(kVTanh) { BenchXYNKernel(); } -BENCH_FP32_CPU(kVCopy) { BenchXYNKernel(); } - -// lstm and peephole -BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel(); } -BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel(); } - -// gru functions -BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel(); } -BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel(); } -BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel(); } - -// seq pool function -BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel(); } - -// embedding seq pool function -BENCH_FP32_CPU(kEmbSeqPool) { - BenchEmbSeqPoolKernel(); -} +#define BenchKernelLSTMCtHt BenchKernelLSTM +#define BenchKernelLSTMC1H1 BenchKernelLSTM -// sgd function -BENCH_FP32_CPU(kSgd) { BenchSgdKernel(); } +#define BenchKernelGRUH1 BenchKernelGRU +#define BenchKernelGRUHtPart1 BenchKernelGRU +#define BenchKernelGRUHtPart2 BenchKernelGRU -// matmul -BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel(); } +using CPUPlace = paddle::platform::CPUPlace; -// softmax -BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel(); } +#define BENCH_FP32_CPU(name) \ + BENCH_JITKERNEL(name, FP32, CPU) { \ + BenchKernel##name, CPUPlace>(); \ + } -// layernorm -BENCH_FP32_CPU(kLayerNorm) { - BenchLayerNormKernel(); -} +// xyzn +BENCH_FP32_CPU(VMul); +BENCH_FP32_CPU(VAdd); +BENCH_FP32_CPU(VAddRelu); +BENCH_FP32_CPU(VSub); -// crfdecoding -BENCH_FP32_CPU(kCRFDecoding) { - BenchCRFDecodingKernel(); -} +// axyn +BENCH_FP32_CPU(VScal); +BENCH_FP32_CPU(VAddBias); -// vbroadcast function -BENCH_FP32_CPU(kVBroadcast) { - BenchVBroadcastKernel(); -} +// xyn +BENCH_FP32_CPU(VRelu); +BENCH_FP32_CPU(VIdentity); +BENCH_FP32_CPU(VSquare); +BENCH_FP32_CPU(VExp); +BENCH_FP32_CPU(VSigmoid); +BENCH_FP32_CPU(VTanh); +BENCH_FP32_CPU(VCopy); + +// xrn +BENCH_FP32_CPU(HMax); +BENCH_FP32_CPU(HSum); + +// LSTM +BENCH_FP32_CPU(LSTMCtHt); +BENCH_FP32_CPU(LSTMC1H1); + +// GRU 
+BENCH_FP32_CPU(GRUH1); +BENCH_FP32_CPU(GRUHtPart1); +BENCH_FP32_CPU(GRUHtPart2); + +BENCH_FP32_CPU(LayerNorm); +BENCH_FP32_CPU(CRFDecoding); + +BENCH_FP32_CPU(SeqPool); +BENCH_FP32_CPU(EmbSeqPool); +BENCH_FP32_CPU(MatMul); +BENCH_FP32_CPU(Softmax); +BENCH_FP32_CPU(Sgd); +BENCH_FP32_CPU(VBroadcast); // Benchmark all jit kernels including jitcode, mkl and refer. // To use this tool, run command: ./benchmark [options...] diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 1af1add3ee..85f4072dd3 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -19,6 +19,8 @@ extern "C" { } #include #include +#include +#include // for std::move #include #include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/kernel_base.h" @@ -30,22 +32,22 @@ namespace paddle { namespace operators { namespace jit { -template +template inline typename std::enable_if< - std::is_same::value && + std::is_same::value && std::is_same::value, - typename KernelTuples::func_type>::type -GetJitCode(const typename KernelTuples::attr_type& attr) { - using Func = typename KernelTuples::func_type; - using Attr = typename KernelTuples::attr_type; + typename KernelTuple::func_type>::type +GetJitCode(const typename KernelTuple::attr_type& attr) { + using Func = typename KernelTuple::func_type; + using Attr = typename KernelTuple::attr_type; size_t key = JitCodeKey(attr); - auto& codes = JitCodePool().Instance(); + auto& codes = JitCodePool().Instance(); if (codes.Has(key)) { return codes.AllKernels().at(key)->template getCode(); } // creator is not related with attr, so can use KernelKey as key - KernelKey kkey(KT, PlaceType()); + KernelKey kkey(KernelTuple::kernel_type, PlaceType()); // pool: (KernelKey(type, place), vector) auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); auto iter = creator_map.find(kkey); @@ -66,27 +68,27 @@ GetJitCode(const typename KernelTuples::attr_type& attr) { return nullptr; } -template +template inline typename std::enable_if< - !std::is_same::value || + !std::is_same::value || !std::is_same::value, - typename KernelTuples::func_type>::type -GetJitCode(const typename KernelTuples::attr_type& attr) { + typename KernelTuple::func_type>::type +GetJitCode(const typename KernelTuple::attr_type& attr) { return nullptr; } // Refer code do not related with attr, which is just for cast // Refer is always on CPUPlace -template -inline typename KernelTuples::func_type GetRefer() { +template +inline typename KernelTuple::func_type GetRefer() { auto& ref_pool = ReferKernelPool().Instance().AllKernels(); - KernelKey kkey(KT, platform::CPUPlace()); + KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace()); auto ref_iter = ref_pool.find(kkey); PADDLE_ENFORCE(ref_iter != ref_pool.end(), "Every Kernel should have reference function."); auto& ref_impls = ref_iter->second; for (auto& impl : ref_impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i) { return i->GetFunc(); } @@ -94,23 +96,22 @@ inline typename KernelTuples::func_type GetRefer() { return nullptr; } -template -typename KernelTuples::func_type Get( - const typename KernelTuples::attr_type& attr) { - auto jitfunc = GetJitCode(attr); +template +typename KernelTuple::func_type Get( + const typename KernelTuple::attr_type& attr) { + auto jitfunc = GetJitCode(attr); if (jitfunc) { return jitfunc; } // pool: (KernelKey(type, place), vector) - KernelKey kkey(KT, PlaceType()); + KernelKey 
kkey(KernelTuple::kernel_type, PlaceType()); auto& pool = KernelPool().Instance().AllKernels(); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { return i->GetFunc(); } @@ -118,48 +119,50 @@ typename KernelTuples::func_type Get( } // The last implementation should be reference function on CPUPlace. - return GetRefer(); + return GetRefer(); } -template +template class KernelFuncs { public: KernelFuncs() = default; static KernelFuncs& Cache() { - static thread_local KernelFuncs g_func_cache; + static thread_local KernelFuncs g_func_cache; return g_func_cache; } // the exposed interface to use - typename KernelTuples::func_type At( - const typename KernelTuples::attr_type& attr) { + typename KernelTuple::func_type At( + const typename KernelTuple::attr_type& attr) { // XXH64: 13.8 GB/s - int64_t key = XXH64(&attr, sizeof(typename KernelTuples::attr_type), 0); + // TODO(TJ): change me, maybe not all attr change need one key, should be + // attrkey + int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0); if (Has(key)) { return funcs_.at(key); } // If do not have this attr in cache, // then could run some runtime benchmark of this attr and save the best one. // Here just get the offline benchmarked best one. - auto func = Get(attr); + auto func = Get(attr); Insert(key, func); return func; } - typename KernelTuples::func_type operator[]( - const typename KernelTuples::attr_type& attr) { + typename KernelTuple::func_type operator[]( + const typename KernelTuple::attr_type& attr) { return At(attr); } protected: bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); } - void Insert(int64_t key, typename KernelTuples::func_type func) { + void Insert(int64_t key, typename KernelTuple::func_type func) { funcs_.emplace(key, func); } private: - std::unordered_map funcs_; + std::unordered_map funcs_; DISABLE_COPY_AND_ASSIGN(KernelFuncs); }; diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 96e162a21b..e8dbcced4f 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -62,26 +62,55 @@ typedef enum { kSqrt, } SeqPoolType; +// x, y, z, n template -struct XYZNTuples { +struct XYZNTuple { typedef T data_type; typedef int attr_type; typedef void (*func_type)(const T*, const T*, T*, int); }; +// a, x, y, n template -struct AXYNTuples : public XYZNTuples {}; +struct AXYNTuple : public XYZNTuple {}; +// x, y, n template -struct XYNTuples { +struct XYNTuple { typedef T data_type; typedef int attr_type; typedef void (*func_type)(const T*, T*, int); }; -// x, return and int +// x, returned value, n template -struct XRNTuples : public XYNTuples {}; +struct XRNTuple : public XYNTuple {}; + +#define DECLARE_KERNELTUPLE(kernel_tuple, type) \ + template \ + struct type##Tuple : public kernel_tuple { \ + static constexpr KernelType kernel_type = k##type; \ + } + +// Tuple should be corresponding to the KernelType +DECLARE_KERNELTUPLE(XYZNTuple, VMul); +DECLARE_KERNELTUPLE(XYZNTuple, VAdd); +DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu); +DECLARE_KERNELTUPLE(XYZNTuple, VSub); + +DECLARE_KERNELTUPLE(AXYNTuple, VScal); +DECLARE_KERNELTUPLE(AXYNTuple, VAddBias); + +DECLARE_KERNELTUPLE(XYNTuple, VRelu); +DECLARE_KERNELTUPLE(XYNTuple, VIdentity); +DECLARE_KERNELTUPLE(XYNTuple, VSquare); +DECLARE_KERNELTUPLE(XYNTuple, VExp); 
+DECLARE_KERNELTUPLE(XYNTuple, VSigmoid); +DECLARE_KERNELTUPLE(XYNTuple, VTanh); +DECLARE_KERNELTUPLE(XYNTuple, VCopy); + +DECLARE_KERNELTUPLE(XRNTuple, HMax); +DECLARE_KERNELTUPLE(XRNTuple, HSum); typedef struct { void* gates; // gates: x_ch, x_ih, x_fh, x_oh @@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t; typedef struct lstm_attr_s lstm_attr_t; template -struct LSTMTuples { +struct LSTMTuple { typedef T data_type; typedef lstm_attr_t attr_type; typedef void (*func_type)(lstm_t*, const lstm_attr_t*); }; template -struct GRUTuples { +struct GRUTuple { typedef T data_type; typedef gru_attr_t attr_type; typedef void (*func_type)(gru_t*, const gru_attr_t*); }; +DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt); +DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1); + +DECLARE_KERNELTUPLE(GRUTuple, GRUH1); +DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1); +DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2); + +#undef DECLARE_KERNELTUPLE + template -struct VBroadcastTuples { +struct VBroadcastTuple { + static constexpr KernelType kernel_type = kVBroadcast; typedef T data_type; typedef int64_t attr_type; typedef void (*func_type)(const T*, T*, int64_t, int64_t); @@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s { } seq_pool_attr_t; template -struct SeqPoolTuples { +struct SeqPoolTuple { + static constexpr KernelType kernel_type = kSeqPool; typedef T data_type; typedef seq_pool_attr_t attr_type; typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); @@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s { } emb_seq_pool_attr_t; template -struct EmbSeqPoolTuples { +struct EmbSeqPoolTuple { + static constexpr KernelType kernel_type = kEmbSeqPool; typedef T data_type; typedef emb_seq_pool_attr_t attr_type; typedef void (*func_type)(const T*, const int64_t*, T*, @@ -198,7 +239,8 @@ typedef struct sgd_attr_s { } sgd_attr_t; template -struct SgdTuples { +struct SgdTuple { + static constexpr KernelType kernel_type = kSgd; typedef T data_type; typedef sgd_attr_t attr_type; typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*, @@ -214,21 +256,24 @@ typedef struct matmul_attr_s { } matmul_attr_t; template -struct MatMulTuples { +struct MatMulTuple { + static constexpr KernelType kernel_type = kMatMul; typedef T data_type; typedef matmul_attr_t attr_type; typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*); }; template -struct CRFDecodingTuples { +struct CRFDecodingTuple { + static constexpr KernelType kernel_type = kCRFDecoding; typedef T data_type; typedef int attr_type; typedef void (*func_type)(const int, const T*, const T*, T*, int*, int); }; template -struct LayerNormTuples { +struct LayerNormTuple { + static constexpr KernelType kernel_type = kLayerNorm; typedef T data_type; typedef int attr_type; typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int, @@ -236,7 +281,8 @@ struct LayerNormTuples { }; template -struct SoftmaxTuples { +struct SoftmaxTuple { + static constexpr KernelType kernel_type = kSoftmax; typedef T data_type; typedef int attr_type; typedef void (*func_type)(const T*, T*, int, int); @@ -244,7 +290,8 @@ struct SoftmaxTuples { // nChw16c = nChw16c .* NC template -struct NCHW16CMulNCTuples { +struct NCHW16CMulNCTuple { + static constexpr KernelType kernel_type = kNCHW16CMulNC; typedef T data_type; typedef int attr_type; typedef void (*func_type)(const T*, const T*, T*, int, int); @@ -258,12 +305,12 @@ class Kernel { DISABLE_COPY_AND_ASSIGN(Kernel); }; -template +template class KernelMore : public Kernel { public: - using T = typename 
KernelTuples::data_type; - using Func = typename KernelTuples::func_type; - using Attr = typename KernelTuples::attr_type; + using T = typename KernelTuple::data_type; + using Func = typename KernelTuple::func_type; + using Attr = typename KernelTuple::attr_type; virtual Func GetFunc() const { return func; } virtual bool UseMe(const Attr& attr) const = 0; virtual const char* ImplType() const = 0; @@ -272,11 +319,11 @@ class KernelMore : public Kernel { Func func{nullptr}; }; -template -class ReferKernel : public KernelMore { +template +class ReferKernel : public KernelMore { public: // Refer code can always be used - bool UseMe(const typename KernelTuples::attr_type& attr) const override { + bool UseMe(const typename KernelTuple::attr_type& attr) const override { return true; } const char* ImplType() const override { return "Refer"; } diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h index 24179d90dd..f4187bd3ba 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h +++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h @@ -26,11 +26,10 @@ namespace intrinsic { void CRFDecoding(const int seq_len, const float* x, const float* w, float* alpha, int* track, int tag_num); -class CRFDecodingKernel : public KernelMore> { +class CRFDecodingKernel : public KernelMore> { public: CRFDecodingKernel() { this->func = CRFDecoding; } - bool UseMe( - const typename CRFDecodingTuples::attr_type&) const override; + bool UseMe(const typename CRFDecodingTuple::attr_type&) const override; const char* ImplType() const override { return "Intrinsic"; } }; diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h index 89da2940f4..dfa4c2f072 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h +++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h @@ -27,10 +27,10 @@ void LayerNorm(float* x, float* out, float* mean, float* var, const float* scale, const float* bias, int height, const float epsilon, int right); -class LayerNormKernel : public KernelMore> { +class LayerNormKernel : public KernelMore> { public: LayerNormKernel() { this->func = LayerNorm; } - bool UseMe(const typename LayerNormTuples::attr_type&) const override; + bool UseMe(const typename LayerNormTuple::attr_type&) const override; const char* ImplType() const override { return "Intrinsic"; } }; diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc index 0036d1c238..9ee1032e95 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.cc +++ b/paddle/fluid/operators/jit/more/mix/mix.cc @@ -23,6 +23,8 @@ namespace jit { namespace more { namespace mix { +using CPUPlace = platform::CPUPlace; + void VSigmoid(const T* x, T* y, int n) { const float min = SIGMOID_THRESHOLD_MIN; const float max = SIGMOID_THRESHOLD_MAX; @@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) { y[i] = (x[i] < min) ? min : ((x[i] > max) ? 
max : x[i]); y[i] = static_cast(0) - y[i]; } - auto compute = Get, platform::CPUPlace>(n); + auto compute = KernelFuncs, CPUPlace>::Cache().At(n); compute(y, y, n); for (int i = 0; i < n; ++i) { y[i] = static_cast(1) / (static_cast(1) + y[i]); @@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) { void VTanh(const T* x, T* y, int n) { const T a = 2, b = -1; - auto compute_scal = Get, platform::CPUPlace>(n); - auto compute_addbias = Get, platform::CPUPlace>(n); - auto compute_sigmoid = Get, platform::CPUPlace>(n); + auto compute_scal = KernelFuncs, CPUPlace>::Cache().At(n); + auto compute_addbias = KernelFuncs, CPUPlace>::Cache().At(n); + auto compute_sigmoid = KernelFuncs, CPUPlace>::Cache().At(n); compute_scal(&a, x, y, n); compute_sigmoid(y, y, n); compute_scal(&a, y, y, n); @@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) { } void Softmax(const T* x, T* y, int n, int bs) { - auto compute_hmax = - KernelFuncs, platform::CPUPlace>::Cache().At(n); - auto compute_hsum = - KernelFuncs, platform::CPUPlace>::Cache().At(n); - auto compute_vscal = - KernelFuncs, platform::CPUPlace>::Cache().At(n); + auto compute_hmax = KernelFuncs, CPUPlace>::Cache().At(n); + auto compute_hsum = KernelFuncs, CPUPlace>::Cache().At(n); + auto compute_vscal = KernelFuncs, CPUPlace>::Cache().At(n); auto compute_vaddbias = - KernelFuncs, platform::CPUPlace>::Cache().At(n); - auto compute_vexp = - KernelFuncs, platform::CPUPlace>::Cache().At(n); + KernelFuncs, CPUPlace>::Cache().At(n); + auto compute_vexp = KernelFuncs, CPUPlace>::Cache().At(n); for (int i = 0; i < bs; ++i) { T scalar; @@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) { void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT if (type == kVSigmoid) { - return Get, platform::CPUPlace>(d); + return KernelFuncs, CPUPlace>::Cache().At(d); } else if (type == kVRelu) { - return Get, platform::CPUPlace>(d); + return KernelFuncs, CPUPlace>::Cache().At(d); } else if (type == kVTanh) { - return Get, platform::CPUPlace>(d); + return KernelFuncs, CPUPlace>::Cache().At(d); } else if (type == kVIdentity) { - return Get, platform::CPUPlace>(d); + return KernelFuncs, CPUPlace>::Cache().At(d); } PADDLE_THROW("Not support type: %s", type); return nullptr; @@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { const int d = attr->d; const int d2 = d * 2; const int d3 = d * 3; - auto vmul_d = Get, platform::CPUPlace>(d); - auto vadd_d = Get, platform::CPUPlace>(d); - auto vadd_d2 = Get, platform::CPUPlace>(d2); + auto vmul_d = KernelFuncs, CPUPlace>::Cache().At(d); + auto vadd_d = KernelFuncs, CPUPlace>::Cache().At(d); + auto vadd_d2 = KernelFuncs, CPUPlace>::Cache().At(d2); auto act_gate_d = getActFunc(attr->act_gate, d); auto act_gate_d2 = getActFunc(attr->act_gate, d2); auto act_gate_d3 = getActFunc(attr->act_gate, d3); @@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { int d = attr->d; int d2 = d * 2; int d3 = d * 3; - auto vmul_d = Get, platform::CPUPlace>(d); - auto vadd_d = Get, platform::CPUPlace>(d); + auto vmul_d = KernelFuncs, CPUPlace>::Cache().At(d); + auto vadd_d = KernelFuncs, CPUPlace>::Cache().At(d); auto act_gate_d = getActFunc(attr->act_gate, d); auto act_cand_d = getActFunc(attr->act_cand, d); auto act_cell_d = getActFunc(attr->act_cell, d); @@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) { int d2 = d * 2; auto act_gate = getActFunc(attr->act_gate, d); auto act_cand = getActFunc(attr->act_cand, d); - auto vmul_d = Get, platform::CPUPlace>(d); + 
auto vmul_d = KernelFuncs, CPUPlace>::Cache().At(d); act_gate(gates, gates, d); act_cand(gates + d2, gates + d2, d); vmul_d(gates, gates + d2, ht, d); @@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { T* ht = reinterpret_cast(step->ht); const T* ht_1 = reinterpret_cast(step->ht_1); auto act_gate = getActFunc(attr->act_gate, attr->d); - auto vmul_d = Get, platform::CPUPlace>(attr->d); + auto vmul_d = KernelFuncs, CPUPlace>::Cache().At(attr->d); act_gate(gates + attr->d, gates + attr->d, attr->d); vmul_d(ht_1, gates + attr->d, ht, attr->d); } @@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; } namespace mix = paddle::operators::jit::more::mix; -#define REGISTER_MORE_KERNEL(key, func) \ - REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel) - -REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid); -REGISTER_MORE_KERNEL(kVTanh, VTanh); -REGISTER_MORE_KERNEL(kSoftmax, Softmax); -REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt); -REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1); -REGISTER_MORE_KERNEL(kGRUH1, GRUH1); -REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1); -REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2); +#define REGISTER_MORE_KERNEL(func) \ + REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel) + +REGISTER_MORE_KERNEL(VSigmoid); +REGISTER_MORE_KERNEL(VTanh); +REGISTER_MORE_KERNEL(Softmax); +REGISTER_MORE_KERNEL(LSTMCtHt); +REGISTER_MORE_KERNEL(LSTMC1H1); +REGISTER_MORE_KERNEL(GRUH1); +REGISTER_MORE_KERNEL(GRUHtPart1); +REGISTER_MORE_KERNEL(GRUHtPart2); #undef REGISTER_MORE_KERNEL diff --git a/paddle/fluid/operators/jit/more/mix/mix.h b/paddle/fluid/operators/jit/more/mix/mix.h index d64af19219..17eb96462f 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.h +++ b/paddle/fluid/operators/jit/more/mix/mix.h @@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr); void GRUHtPart1(gru_t* step, const gru_attr_t* attr); void GRUHtPart2(gru_t* step, const gru_attr_t* attr); -#define DECLARE_MORE_KERNEL(name, tuples) \ - class name##Kernel : public KernelMore> { \ - public: \ - name##Kernel() { this->func = name; } \ - bool UseMe(const typename tuples::attr_type&) const override; \ - const char* ImplType() const override { return "Mixed"; } \ +#define DECLARE_MORE_KERNEL(name) \ + class name##Kernel : public KernelMore> { \ + public: \ + name##Kernel() { this->func = name; } \ + bool UseMe(const typename name##Tuple::attr_type&) const override; \ + const char* ImplType() const override { return "Mixed"; } \ } // XYN -DECLARE_MORE_KERNEL(VSigmoid, XYNTuples); -DECLARE_MORE_KERNEL(VTanh, XYNTuples); +DECLARE_MORE_KERNEL(VSigmoid); +DECLARE_MORE_KERNEL(VTanh); // XRN -DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples); +DECLARE_MORE_KERNEL(Softmax); -DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples); -DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples); +DECLARE_MORE_KERNEL(LSTMCtHt); +DECLARE_MORE_KERNEL(LSTMC1H1); -DECLARE_MORE_KERNEL(GRUH1, GRUTuples); -DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples); -DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples); +DECLARE_MORE_KERNEL(GRUH1); +DECLARE_MORE_KERNEL(GRUHtPart1); +DECLARE_MORE_KERNEL(GRUHtPart2); #undef DECLARE_MORE_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 4f51353bce..084ea571ce 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -250,23 +250,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax); namespace mkl = paddle::operators::jit::more::mkl; -#define REGISTER_MKL_KERNEL(key, func) \ - 
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel, \ +#define REGISTER_MKL_KERNEL(func) \ + REGISTER_JITKERNEL_MORE(k##func, mkl, mkl::func##Kernel, \ mkl::func##Kernel) -REGISTER_MKL_KERNEL(kMatMul, MatMul); -REGISTER_MKL_KERNEL(kVMul, VMul); -REGISTER_MKL_KERNEL(kVAdd, VAdd); -REGISTER_MKL_KERNEL(kVScal, VScal); -REGISTER_MKL_KERNEL(kVExp, VExp); -REGISTER_MKL_KERNEL(kVSquare, VSquare); -REGISTER_MKL_KERNEL(kVCopy, VCopy); -REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast); -REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); -REGISTER_MKL_KERNEL(kVTanh, VTanh); -REGISTER_MKL_KERNEL(kSeqPool, SeqPool); -REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool); -REGISTER_MKL_KERNEL(kSoftmax, Softmax); -REGISTER_MKL_KERNEL(kSgd, Sgd); +REGISTER_MKL_KERNEL(MatMul); +REGISTER_MKL_KERNEL(VMul); +REGISTER_MKL_KERNEL(VAdd); +REGISTER_MKL_KERNEL(VScal); +REGISTER_MKL_KERNEL(VExp); +REGISTER_MKL_KERNEL(VSquare); +REGISTER_MKL_KERNEL(VCopy); +REGISTER_MKL_KERNEL(VBroadcast); +REGISTER_MKL_KERNEL(VSigmoid); +REGISTER_MKL_KERNEL(VTanh); +REGISTER_MKL_KERNEL(SeqPool); +REGISTER_MKL_KERNEL(EmbSeqPool); +REGISTER_MKL_KERNEL(Softmax); +REGISTER_MKL_KERNEL(Sgd); #undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index db2d6faed4..8c1d8b57e0 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, } } -#define DECLARE_MKL_KERNEL(name, tuples) \ - template \ - class name##Kernel : public KernelMore> { \ - public: \ - name##Kernel() { this->func = name; } \ - bool UseMe(const typename tuples::attr_type&) const override; \ - const char* ImplType() const override { return "MKL"; } \ +#define DECLARE_MKL_KERNEL(name) \ + template \ + class name##Kernel : public KernelMore> { \ + public: \ + name##Kernel() { this->func = name; } \ + bool UseMe(const typename name##Tuple::attr_type&) const override; \ + const char* ImplType() const override { return "MKL"; } \ } // ABCMNK -DECLARE_MKL_KERNEL(MatMul, MatMulTuples); +DECLARE_MKL_KERNEL(MatMul); // XYZN -DECLARE_MKL_KERNEL(VMul, XYZNTuples); -DECLARE_MKL_KERNEL(VAdd, XYZNTuples); +DECLARE_MKL_KERNEL(VMul); +DECLARE_MKL_KERNEL(VAdd); // AXYN -DECLARE_MKL_KERNEL(VScal, AXYNTuples); +DECLARE_MKL_KERNEL(VScal); // XYN -DECLARE_MKL_KERNEL(VExp, XYNTuples); -DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); -DECLARE_MKL_KERNEL(VTanh, XYNTuples); -DECLARE_MKL_KERNEL(VSquare, XYNTuples); -DECLARE_MKL_KERNEL(VCopy, XYNTuples); - -DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); - -DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples); - -DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples); - -DECLARE_MKL_KERNEL(Sgd, SgdTuples); - -DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples); +DECLARE_MKL_KERNEL(VExp); +DECLARE_MKL_KERNEL(VSigmoid); +DECLARE_MKL_KERNEL(VTanh); +DECLARE_MKL_KERNEL(VSquare); +DECLARE_MKL_KERNEL(VCopy); + +// others +DECLARE_MKL_KERNEL(SeqPool); +DECLARE_MKL_KERNEL(EmbSeqPool); +DECLARE_MKL_KERNEL(Softmax); +DECLARE_MKL_KERNEL(Sgd); +DECLARE_MKL_KERNEL(VBroadcast); #undef DECLARE_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index c279d1b2ca..0d1c477090 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -17,51 +17,43 @@ namespace refer = paddle::operators::jit::refer; -#define REGISTER_REFER_KERNEL(key, func) \ - REGISTER_JITKERNEL_REFER(key, refer::func##Kernel, \ +#define 
REGISTER_REFER_KERNEL(func) \ + REGISTER_JITKERNEL_REFER(k##func, refer::func##Kernel, \ refer::func##Kernel) -REGISTER_REFER_KERNEL(kVMul, VMul); -REGISTER_REFER_KERNEL(kVAdd, VAdd); -REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu); -REGISTER_REFER_KERNEL(kVSub, VSub); - -REGISTER_REFER_KERNEL(kVScal, VScal); -REGISTER_REFER_KERNEL(kVAddBias, VAddBias); - -REGISTER_REFER_KERNEL(kVRelu, VRelu); -REGISTER_REFER_KERNEL(kVCopy, VCopy); -REGISTER_REFER_KERNEL(kVIdentity, VIdentity); -REGISTER_REFER_KERNEL(kVSquare, VSquare); -REGISTER_REFER_KERNEL(kVExp, VExp); -REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid); -REGISTER_REFER_KERNEL(kVTanh, VTanh); - -REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt); -REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1); - -REGISTER_REFER_KERNEL(kGRUH1, GRUH1); -REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1); -REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2); - -REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding); -REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); - -REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); - -REGISTER_REFER_KERNEL(kSeqPool, SeqPool); - -REGISTER_REFER_KERNEL(kMatMul, MatMul); - -REGISTER_REFER_KERNEL(kHMax, HMax); -REGISTER_REFER_KERNEL(kHSum, HSum); - -REGISTER_REFER_KERNEL(kSoftmax, Softmax); - -REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool); - -REGISTER_REFER_KERNEL(kSgd, Sgd); - -REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast); +REGISTER_REFER_KERNEL(VMul); +REGISTER_REFER_KERNEL(VAdd); +REGISTER_REFER_KERNEL(VAddRelu); +REGISTER_REFER_KERNEL(VSub); + +REGISTER_REFER_KERNEL(VScal); +REGISTER_REFER_KERNEL(VAddBias); + +REGISTER_REFER_KERNEL(VRelu); +REGISTER_REFER_KERNEL(VCopy); +REGISTER_REFER_KERNEL(VIdentity); +REGISTER_REFER_KERNEL(VSquare); +REGISTER_REFER_KERNEL(VExp); +REGISTER_REFER_KERNEL(VSigmoid); +REGISTER_REFER_KERNEL(VTanh); + +REGISTER_REFER_KERNEL(LSTMCtHt); +REGISTER_REFER_KERNEL(LSTMC1H1); + +REGISTER_REFER_KERNEL(GRUH1); +REGISTER_REFER_KERNEL(GRUHtPart1); +REGISTER_REFER_KERNEL(GRUHtPart2); + +REGISTER_REFER_KERNEL(CRFDecoding); +REGISTER_REFER_KERNEL(LayerNorm); +REGISTER_REFER_KERNEL(NCHW16CMulNC); +REGISTER_REFER_KERNEL(SeqPool); +REGISTER_REFER_KERNEL(MatMul); +REGISTER_REFER_KERNEL(HMax); +REGISTER_REFER_KERNEL(HSum); +REGISTER_REFER_KERNEL(Softmax); +REGISTER_REFER_KERNEL(EmbSeqPool); +REGISTER_REFER_KERNEL(Sgd); +REGISTER_REFER_KERNEL(VBroadcast); #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index b3b2097828..cac705a484 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, } } -#define DECLARE_REFER_KERNEL(name, tuples) \ - template \ - class name##Kernel : public ReferKernel> { \ - public: \ - name##Kernel() { this->func = name; } \ +#define DECLARE_REFER_KERNEL(name) \ + template \ + class name##Kernel : public ReferKernel> { \ + public: \ + name##Kernel() { this->func = name; } \ } // const T* x, const T* y, T* z, int n -DECLARE_REFER_KERNEL(VMul, XYZNTuples); -DECLARE_REFER_KERNEL(VAdd, XYZNTuples); -DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples); -DECLARE_REFER_KERNEL(VSub, XYZNTuples); +DECLARE_REFER_KERNEL(VMul); +DECLARE_REFER_KERNEL(VAdd); +DECLARE_REFER_KERNEL(VAddRelu); +DECLARE_REFER_KERNEL(VSub); // const T* a, const T* x, T* y, int n -DECLARE_REFER_KERNEL(VScal, AXYNTuples); -DECLARE_REFER_KERNEL(VAddBias, AXYNTuples); +DECLARE_REFER_KERNEL(VScal); +DECLARE_REFER_KERNEL(VAddBias); // const T* x, T* y, int n 
-DECLARE_REFER_KERNEL(VRelu, XYNTuples); -DECLARE_REFER_KERNEL(VIdentity, XYNTuples); -DECLARE_REFER_KERNEL(VExp, XYNTuples); -DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); -DECLARE_REFER_KERNEL(VTanh, XYNTuples); -DECLARE_REFER_KERNEL(VSquare, XYNTuples); -DECLARE_REFER_KERNEL(VCopy, XYNTuples); +DECLARE_REFER_KERNEL(VRelu); +DECLARE_REFER_KERNEL(VIdentity); +DECLARE_REFER_KERNEL(VExp); +DECLARE_REFER_KERNEL(VSigmoid); +DECLARE_REFER_KERNEL(VTanh); +DECLARE_REFER_KERNEL(VSquare); +DECLARE_REFER_KERNEL(VCopy); // lstm_t*, const lstm_attr_t* -DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); -DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples); +DECLARE_REFER_KERNEL(LSTMCtHt); +DECLARE_REFER_KERNEL(LSTMC1H1); // gru_t*, const gru_attr_t* -DECLARE_REFER_KERNEL(GRUH1, GRUTuples); -DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples); -DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples); - -DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples); -DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); - -DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); - -DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); - -DECLARE_REFER_KERNEL(MatMul, MatMulTuples); - -DECLARE_REFER_KERNEL(HMax, XRNTuples); -DECLARE_REFER_KERNEL(HSum, XRNTuples); - -DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples); - -DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples); - -DECLARE_REFER_KERNEL(Sgd, SgdTuples); - -DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples); +DECLARE_REFER_KERNEL(GRUH1); +DECLARE_REFER_KERNEL(GRUHtPart1); +DECLARE_REFER_KERNEL(GRUHtPart2); + +DECLARE_REFER_KERNEL(HMax); +DECLARE_REFER_KERNEL(HSum); + +// others +DECLARE_REFER_KERNEL(CRFDecoding); +DECLARE_REFER_KERNEL(LayerNorm); +DECLARE_REFER_KERNEL(NCHW16CMulNC); +DECLARE_REFER_KERNEL(SeqPool); +DECLARE_REFER_KERNEL(MatMul); +DECLARE_REFER_KERNEL(Softmax); +DECLARE_REFER_KERNEL(EmbSeqPool); +DECLARE_REFER_KERNEL(Sgd); +DECLARE_REFER_KERNEL(VBroadcast); #undef DECLARE_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 18f8c09f14..a574bf2079 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -64,413 +64,43 @@ std::vector TestSizes() { namespace jit = paddle::operators::jit; using CPUPlace = paddle::platform::CPUPlace; -template -struct TestFuncWithRefer { - void operator()(const typename KernelTuples::func_type tgt, Args... 
args) { - LOG(FATAL) << "Should specify this function."; - } -}; - -template -struct TestFuncWithRefer, std::vector, std::vector, - std::vector> { - void operator()(const typename jit::XYZNTuples::func_type tgt, - const std::vector& x, const std::vector& y, - const std::vector& zref) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(zref.size(), x.size()); - EXPECT_EQ(zref.size(), y.size()); - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* zref_data = zref.data(); - const int d = zref.size(); - - std::vector ztgt(d); - T* ztgt_data = ztgt.data(); - // test normal - tgt(x_data, y_data, ztgt_data, d); - ExpectEQ(ztgt_data, zref_data, d); - // test inplace x - std::copy(x.begin(), x.end(), ztgt.begin()); - tgt(ztgt_data, y_data, ztgt_data, d); - ExpectEQ(ztgt_data, zref_data, d); - // test inplace y - std::copy(y.begin(), y.end(), ztgt.begin()); - tgt(x_data, ztgt_data, ztgt_data, d); - ExpectEQ(ztgt_data, zref_data, d); - } -}; - -template -struct TestFuncWithRefer, T, std::vector, - std::vector> { - void operator()(const typename jit::AXYNTuples::func_type tgt, const T a, - const std::vector& x, const std::vector& yref) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(yref.size(), x.size()); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - const int d = yref.size(); - std::vector ytgt(d); - T* ytgt_data = ytgt.data(); - // test normal - tgt(&a, x_data, ytgt_data, d); - ExpectEQ(ytgt_data, yref_data, d); - // test inplace x - std::copy(x.begin(), x.end(), ytgt.begin()); - tgt(&a, ytgt_data, ytgt_data, d); - ExpectEQ(ytgt_data, yref_data, d); - } -}; - -template -struct TestFuncWithRefer, std::vector, std::vector, - int, int> { - void operator()(const typename jit::SoftmaxTuples::func_type tgt, - const std::vector& x, const std::vector& yref, int n, - int bs) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(yref.size(), x.size()); - EXPECT_EQ(x.size(), static_cast(n * bs)); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - std::vector ytgt(n * bs); - T* ytgt_data = ytgt.data(); - // test normal - tgt(x_data, ytgt_data, n, bs); - ExpectEQ(ytgt_data, yref_data, n * bs); - // test inplace x - std::copy(x.begin(), x.end(), ytgt.begin()); - tgt(ytgt_data, ytgt_data, n, bs); - ExpectEQ(ytgt_data, yref_data, n * bs); - } -}; - -template -struct TestFuncWithRefer, std::vector, T> { - void operator()(const typename jit::XRNTuples::func_type tgt, - const std::vector& x, const T ref_res) { - EXPECT_TRUE(tgt != nullptr); - T tgt_res; - tgt(x.data(), &tgt_res, x.size()); - ExpectEQ(&tgt_res, &ref_res, 1); - } -}; - -template -struct TestFuncWithRefer, std::vector, - std::vector, int64_t, - typename jit::VBroadcastTuples::attr_type> { - void operator()(const typename jit::VBroadcastTuples::func_type tgt, - const std::vector& x, const std::vector& yref, - int64_t h, - const typename jit::VBroadcastTuples::attr_type& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(x.size(), static_cast(attr)); - EXPECT_EQ(yref.size(), x.size() * h); - std::vector y(yref.size()); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - T* y_data = y.data(); - tgt(x_data, y_data, h, attr); - ExpectEQ(y_data, yref_data, yref.size()); - } -}; - -template -struct TestFuncWithRefer, std::vector, std::vector> { - void operator()(const typename jit::XYNTuples::func_type tgt, - const std::vector& x, const std::vector& yref) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(yref.size(), x.size()); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - const int d = 
yref.size(); - std::vector ytgt(d); - T* ytgt_data = ytgt.data(); - // test normal - tgt(x_data, ytgt_data, d); - ExpectEQ(ytgt_data, yref_data, d); - // test inplace x - std::copy(x.begin(), x.end(), ytgt.begin()); - tgt(ytgt_data, ytgt_data, d); - ExpectEQ(ytgt_data, yref_data, d); - } -}; - -template -struct TestFuncWithRefer, std::vector, std::vector, - std::vector, std::vector, std::vector, - typename jit::LSTMTuples::attr_type> { - void operator()(const typename jit::LSTMTuples::func_type tgt, - const std::vector& xsrc, const std::vector& wp, - const std::vector& ct_1, const std::vector& ct_ref, - const std::vector& ht_ref, - const typename jit::LSTMTuples::attr_type& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(ct_ref.size(), ht_ref.size()); - EXPECT_EQ(ct_1.size(), ht_ref.size()); - EXPECT_EQ(xsrc.size(), 4 * ht_ref.size()); - EXPECT_EQ(wp.size(), 3 * ht_ref.size()); - - // x could be changed after compute, so copy to save src - int d = ht_ref.size(); - std::vector x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size()); - std::vector checked(2 * d); - std::copy(xsrc.begin(), xsrc.end(), x.begin()); - - const T* ct_1_data = ct_1.data(); - const T* wp_data = wp.data(); - const T* ct_ref_data = ct_ref.data(); - const T* ht_ref_data = ht_ref.data(); - T* x_data = x.data(); - T* ct_data = ct.data(); - T* ht_data = ht.data(); - T* checked_data = checked.data(); - - jit::lstm_t step; - step.gates = x_data; - step.ct_1 = ct_1_data; - step.ct = ct_data; - step.ht = ht_data; - if (attr.use_peephole) { - step.wp = wp_data; - step.checked = checked_data; - } - - tgt(&step, &attr); - ExpectEQ(ct_data, ct_ref_data, d); - ExpectEQ(ht_data, ht_ref_data, d); - } -}; - -template -struct TestFuncWithRefer, std::vector, std::vector, - std::vector, - typename jit::GRUTuples::attr_type> { - void operator()(const typename jit::GRUTuples::func_type tgt, - const std::vector& xsrc, const std::vector& ht_1, - const std::vector& ht_ref, - const typename jit::GRUTuples::attr_type& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(ht_1.size(), ht_ref.size()); - EXPECT_EQ(xsrc.size(), 3 * ht_ref.size()); - - // x could be changed after compute, so copy to save src - int d = ht_ref.size(); - std::vector x(xsrc.size()), ht(ht_ref.size()); - std::copy(xsrc.begin(), xsrc.end(), x.begin()); - const T* ht_1_data = ht_1.data(); - const T* ht_ref_data = ht_ref.data(); - T* x_data = x.data(); - T* ht_data = ht.data(); - jit::gru_t step; - step.gates = x_data; - step.ht_1 = ht_1_data; - step.ht = ht_data; - tgt(&step, &attr); - ExpectEQ(ht_data, ht_ref_data, d); - } -}; - -template -struct TestFuncWithRefer, std::vector, std::vector, - typename jit::SeqPoolTuples::attr_type> { - void operator()(const typename jit::SeqPoolTuples::func_type tgt, - const std::vector& x, const std::vector& yref, - const typename jit::SeqPoolTuples::attr_type& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(x.size() % yref.size(), static_cast(0)); - int w = yref.size(); - std::vector y(w); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - T* y_data = y.data(); - tgt(x_data, y_data, &attr); - ExpectEQ(y_data, yref_data, w); - } -}; - -template -struct TestFuncWithRefer, std::vector, - std::vector, std::vector, - typename jit::EmbSeqPoolTuples::attr_type> { - void operator()(const typename jit::EmbSeqPoolTuples::func_type tgt, - const std::vector& table, const std::vector& idx, - const std::vector& oref, - const typename jit::EmbSeqPoolTuples::attr_type& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(table.size(), - 
static_cast(attr.table_height * attr.table_width)); - EXPECT_EQ(idx.size(), - static_cast(attr.index_height * attr.index_width)); - EXPECT_EQ(oref.size(), - static_cast(attr.table_width * attr.index_width)); - const T* table_data = table.data(); - const int64_t* idx_data = idx.data(); - const T* oref_data = oref.data(); - int o_w = oref.size(); - std::vector out(o_w); - T* o_data = out.data(); - tgt(table_data, idx_data, o_data, &attr); - ExpectEQ(o_data, oref_data, o_w); - } -}; - -template -struct TestFuncWithRefer, T, std::vector, std::vector, - std::vector, std::vector, - typename jit::SgdTuples::attr_type> { - void operator()(const typename jit::SgdTuples::func_type tgt, const T lr, - const std::vector& param, const std::vector& grad, - const std::vector& rows, const std::vector& oref, - const typename jit::SgdTuples::attr_type& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(param.size(), - static_cast(attr.param_height * attr.param_width)); - EXPECT_EQ(grad.size(), - static_cast(attr.grad_height * attr.grad_width)); - EXPECT_EQ(rows.size(), static_cast(attr.selected_rows_size)); - EXPECT_EQ(param.size(), oref.size()); - const T* param_data = param.data(); - const T* grad_data = grad.data(); - const int64_t* rows_data = rows.data(); - const T* oref_data = oref.data(); - - std::vector out(oref.size()); - T* o_data = out.data(); - tgt(&lr, param_data, grad_data, rows_data, o_data, &attr); - // only the selected rows should be equal - for (size_t i = 0; i < rows.size(); ++i) { - ExpectEQ(o_data + rows[i] * attr.grad_width, - oref_data + rows[i] * attr.grad_width, attr.grad_width); - } - - // inplace - std::copy(param.begin(), param.end(), out.begin()); - tgt(&lr, o_data, grad_data, rows_data, o_data, &attr); - for (size_t i = 0; i < rows.size(); ++i) { - ExpectEQ(o_data + rows[i] * attr.grad_width, - oref_data + rows[i] * attr.grad_width, attr.grad_width); - } - } -}; - -template -struct TestFuncWithRefer, std::vector, std::vector, - std::vector, - typename jit::MatMulTuples::attr_type> { - void operator()(const typename jit::MatMulTuples::func_type tgt, - const std::vector& a, const std::vector& b, - const std::vector& cref, - const typename jit::MatMulTuples::attr_type& attr) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(a.size(), static_cast(attr.m * attr.k)); - EXPECT_EQ(b.size(), static_cast(attr.k * attr.n)); - EXPECT_EQ(cref.size(), static_cast(attr.m * attr.n)); - std::vector c(cref.size()); - const T* a_data = a.data(); - const T* b_data = b.data(); - const T* cref_data = cref.data(); - T* c_data = c.data(); - tgt(a_data, b_data, c_data, &attr); - ExpectEQ(c_data, cref_data, attr.m * attr.n); - } -}; - -template -struct TestFuncWithRefer, std::vector, - std::vector, std::vector, std::vector, - std::vector, std::vector, int, float, int> { - void operator()(const typename jit::LayerNormTuples::func_type tgt, - std::vector& x, std::vector& outref, // NOLINT - std::vector& mean, std::vector& var, // NOLINT - const std::vector& scale, const std::vector& bias, - int left, const float epsilon, int right) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(x.size(), static_cast(left * right)); - EXPECT_EQ(outref.size(), static_cast(left * right)); - EXPECT_EQ(mean.size(), static_cast(left)); - EXPECT_EQ(var.size(), static_cast(left)); - EXPECT_EQ(scale.size(), static_cast(right)); - EXPECT_EQ(bias.size(), static_cast(right)); - std::vector outtgt(outref.size()); - const T* scale_data = scale.data(); - const T* bias_data = bias.data(); - T* x_data = x.data(); - T* mean_data = mean.data(); - 
T* var_data = var.data(); - T* outref_data = outref.data(); - T* outtgt_data = outtgt.data(); - - tgt(x_data, outtgt_data, mean_data, var_data, scale_data, bias_data, left, - epsilon, right); - ExpectEQ(outtgt_data, outref_data, left * right); - } -}; - -template -struct TestFuncWithRefer, int, std::vector, - std::vector, std::vector, std::vector, - int> { - void operator()(const typename jit::CRFDecodingTuples::func_type tgt, - const int seq_len, const std::vector& x, - const std::vector& w, std::vector& alpharef, // NOLINT - std::vector& trackref, int tag_num) { // NOLINT - constexpr int state_trans_base_idx = 2; - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(x.size(), static_cast(seq_len * tag_num)); - EXPECT_EQ(w.size(), - static_cast((tag_num + state_trans_base_idx) * tag_num)); - EXPECT_EQ(alpharef.size(), static_cast(seq_len * tag_num)); - EXPECT_EQ(trackref.size(), static_cast(seq_len * tag_num)); - std::vector alphatgt(alpharef.size()); - std::vector tracktgt(trackref.size()); - - memcpy(trackref.data(), tracktgt.data(), tag_num * sizeof(int)); - tgt(seq_len, (const T*)x.data(), (const T*)w.data(), alphatgt.data(), - tracktgt.data(), tag_num); - ExpectEQ(alpharef.data(), alphatgt.data(), seq_len * tag_num); - ExpectEQ(trackref.data(), tracktgt.data(), seq_len * tag_num); - } -}; - -template -void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { - TestFuncWithRefer test; +void TestAllImpls(const typename KernelTuple::attr_type& attr, + const Tester& verifier, const Args&... args) { // test jitcode - auto jitcode = jit::GetJitCode(attr); + auto jitcode = jit::GetJitCode(attr); if (jitcode) { VLOG(10) << "Test Jitcode Kernel "; - test(jitcode, args...); + verifier(jitcode, args...); } // test all impls in more - jit::KernelKey kkey(KT, PlaceType()); + jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType()); auto& pool = jit::KernelPool().Instance().AllKernels(); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { auto more = i->GetFunc(); VLOG(10) << "Test More Kernel : " << i->ImplType(); - test(more, args...); + verifier(more, args...); } } } // test result from Get function - // VLOG(10) << "Test Get function "; - auto tgt = jit::KernelFuncs::Cache().At(attr); - test(tgt, args...); + VLOG(10) << "Test final get function "; + auto tgt = jit::KernelFuncs::Cache().At(attr); + verifier(tgt, args...); } -template -void TestKernelXYZNTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelXYZN() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); for (int d : TestSizes()) { - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); std::vector x(d), y(d), zref(d); @@ -494,16 +124,42 @@ void TestKernelXYZNTuples() { ExpectEQ(xinp_data, zref_data, d); ExpectEQ(yinp_data, zref_data, d); - TestAllImpls, PlaceType, std::vector, - std::vector, std::vector>(d, x, y, zref); + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& x, const std::vector& y, + const std::vector& zref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(zref.size(), x.size()); + EXPECT_EQ(zref.size(), y.size()); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* zref_data = zref.data(); + const int d = zref.size(); + + std::vector ztgt(d); + T* 
ztgt_data = ztgt.data(); + // test normal + tgt(x_data, y_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ztgt.begin()); + tgt(ztgt_data, y_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + // test inplace y + std::copy(y.begin(), y.end(), ztgt.begin()); + tgt(x_data, ztgt_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + }; + + TestAllImpls(d, verifier, x, y, zref); } } -template -void TestKernelAXYNTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelAXYN() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); for (int d : TestSizes()) { - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); const T a = static_cast(3); @@ -520,34 +176,33 @@ void TestKernelAXYNTuples() { ref(&a, xinp_data, xinp_data, d); ExpectEQ(xinp_data, yref_data, d); - TestAllImpls, PlaceType, T, std::vector, - std::vector>(d, a, x, yref); - } -} - -template -void TestKernelXRNTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); - auto last_acc = FLAGS_acc; - FLAGS_acc = 1e-4; - for (int d : TestSizes()) { - auto ref = jit::GetRefer>(); - EXPECT_TRUE(ref != nullptr); - std::vector x(d); - RandomVec(d, x.data()); - T ref_res; - ref(x.data(), &ref_res, d); - TestAllImpls, PlaceType, std::vector, T>(d, x, - ref_res); + auto verifier = [](const typename KernelTuple::func_type tgt, const T a, + const std::vector& x, const std::vector& yref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(yref.size(), x.size()); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + const int d = yref.size(); + std::vector ytgt(d); + T* ytgt_data = ytgt.data(); + // test normal + tgt(&a, x_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ytgt.begin()); + tgt(&a, ytgt_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + }; + TestAllImpls(d, verifier, a, x, yref); } - FLAGS_acc = last_acc; } -template -void TestKernelXYNTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelXYN() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); for (int d : TestSizes()) { - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); std::vector x(d), yref(d); @@ -562,15 +217,57 @@ void TestKernelXYNTuples() { ref(x_data, yref_data, d); ref(xinp_data, xinp_data, d); ExpectEQ(xinp_data, yref_data, d); + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& x, const std::vector& yref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(yref.size(), x.size()); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + const int d = yref.size(); + std::vector ytgt(d); + T* ytgt_data = ytgt.data(); + // test normal + tgt(x_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ytgt.begin()); + tgt(ytgt_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + }; + TestAllImpls(d, verifier, x, yref); + } +} + +template +void TestKernelXRN() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + auto last_acc = FLAGS_acc; + FLAGS_acc = 1e-4; + for (int d : TestSizes()) { + auto ref = jit::GetRefer(); + EXPECT_TRUE(ref != nullptr); + 
std::vector x(d); + RandomVec(d, x.data()); + T ref_res; + ref(x.data(), &ref_res, d); - TestAllImpls, PlaceType, std::vector, - std::vector>(d, x, yref); + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& x, const T ref_res) { + EXPECT_TRUE(tgt != nullptr); + T tgt_res; + tgt(x.data(), &tgt_res, x.size()); + ExpectEQ(&tgt_res, &ref_res, 1); + }; + TestAllImpls(d, verifier, x, ref_res); } + FLAGS_acc = last_acc; } -template -void TestKernelLSTMTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelLSTM() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); std::vector all_acts = {"sigmoid", "tanh", "relu", "identity"}; auto test_sizes = TestSizes(); test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); @@ -582,7 +279,7 @@ void TestKernelLSTMTuples() { const jit::lstm_attr_t attr( d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand), jit::to_kerneltype(act_cell), use_peephole); - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); std::vector xsrc(4 * d), wp(3 * d), ct_1(d); std::vector ct_ref(d), ht_ref(d), checked(2 * d); @@ -609,10 +306,51 @@ void TestKernelLSTMTuples() { } ref(&step, &attr); VLOG(10) << attr; - TestAllImpls, PlaceType, std::vector, - std::vector, std::vector, std::vector, - std::vector>(attr, xsrc, wp, ct_1, ct_ref, ht_ref, - attr); + + auto verifier = []( + const typename KernelTuple::func_type tgt, + const std::vector& xsrc, const std::vector& wp, + const std::vector& ct_1, const std::vector& ct_ref, + const std::vector& ht_ref, + const typename KernelTuple::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(ct_ref.size(), ht_ref.size()); + EXPECT_EQ(ct_1.size(), ht_ref.size()); + EXPECT_EQ(xsrc.size(), 4 * ht_ref.size()); + EXPECT_EQ(wp.size(), 3 * ht_ref.size()); + + // x could be changed after compute, so copy to save src + int d = ht_ref.size(); + std::vector x(xsrc.size()), ct(ct_ref.size()), + ht(ht_ref.size()); + std::vector checked(2 * d); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + + const T* ct_1_data = ct_1.data(); + const T* wp_data = wp.data(); + const T* ct_ref_data = ct_ref.data(); + const T* ht_ref_data = ht_ref.data(); + T* x_data = x.data(); + T* ct_data = ct.data(); + T* ht_data = ht.data(); + T* checked_data = checked.data(); + + jit::lstm_t step; + step.gates = x_data; + step.ct_1 = ct_1_data; + step.ct = ct_data; + step.ht = ht_data; + if (attr.use_peephole) { + step.wp = wp_data; + step.checked = checked_data; + } + + tgt(&step, &attr); + ExpectEQ(ct_data, ct_ref_data, d); + ExpectEQ(ht_data, ht_ref_data, d); + }; + TestAllImpls(attr, verifier, xsrc, wp, ct_1, + ct_ref, ht_ref, attr); } } } @@ -620,9 +358,10 @@ void TestKernelLSTMTuples() { } } -template -void TestKernelGRUTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelGRU() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); std::vector all_acts = {"sigmoid", "tanh", "relu", "identity"}; auto test_sizes = TestSizes(); test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); @@ -631,7 +370,7 @@ void TestKernelGRUTuples() { for (auto& act_cand : all_acts) { const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand)); - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != 
nullptr); std::vector xsrc(3 * d), ht_1(d), ht_ref(d); RandomVec(3 * d, xsrc.data()); @@ -648,17 +387,216 @@ void TestKernelGRUTuples() { step.ht = ht_ref_data; ref(&step, &attr); VLOG(10) << attr; - TestAllImpls, PlaceType, std::vector, - std::vector, std::vector>(attr, xsrc, ht_1, ht_ref, - attr); + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& xsrc, + const std::vector& ht_1, + const std::vector& ht_ref, + const typename KernelTuple::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(ht_1.size(), ht_ref.size()); + EXPECT_EQ(xsrc.size(), 3 * ht_ref.size()); + + // x could be changed after compute, so copy to save src + int d = ht_ref.size(); + std::vector x(xsrc.size()), ht(ht_ref.size()); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + const T* ht_1_data = ht_1.data(); + const T* ht_ref_data = ht_ref.data(); + T* x_data = x.data(); + T* ht_data = ht.data(); + jit::gru_t step; + step.gates = x_data; + step.ht_1 = ht_1_data; + step.ht = ht_data; + tgt(&step, &attr); + ExpectEQ(ht_data, ht_ref_data, d); + }; + TestAllImpls(attr, verifier, xsrc, ht_1, ht_ref, + attr); } } } } -template -void TestKernelSeqPoolTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelNCHW16CMulNC() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + const int n = 3, c = 16 * 4, h = 10, w = 10; + auto ref = jit::GetRefer(); + EXPECT_TRUE(ref != nullptr); + int sz = n * c * h * w; + std::vector x(sz), y(n * c), zref(sz); + std::vector ztgt(sz), zjit(sz); + RandomVec(sz, x.data()); + RandomVec(n * c, y.data()); + + const T* x_data = x.data(); + const T* y_data = y.data(); + T* zref_data = zref.data(); + T* ztgt_data = ztgt.data(); + T* zjit_data = zjit.data(); + constexpr int simd_width = ZMM_FLOAT_BLOCK; + int C = c / simd_width; + auto tgt = jit::KernelFuncs::Cache().At(0); + auto jitcode = jit::GetJitCode(0); + EXPECT_TRUE(tgt != nullptr); + + if (std::is_same::value && + paddle::platform::MayIUse(paddle::platform::avx512f)) { + EXPECT_TRUE(jitcode != nullptr); + } + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < C; ci++) { + auto ptr_x = + x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + auto ptr_zref = + zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + auto ptr_ztgt = + ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + ref(ptr_x, ptr_y, ptr_zref, h, w); + tgt(ptr_x, ptr_y, ptr_ztgt, h, w); + + if (jitcode) { + auto ptr_zjit = + zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + jitcode(ptr_x, ptr_y, ptr_zjit, h, w); + } + } + } + ExpectEQ(ztgt_data, zref_data, sz); + if (jitcode) { + ExpectEQ(zjit_data, zref_data, sz); + } +} + +template +void TestKernelLayerNorm() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + const T epsilon = 9.99999975e-06; + for (int n : {1, 2, 10}) { + for (int x_dim_0 : {1, 9, 17, 50}) { + int left = n * x_dim_0; + for (int x_dim_1 : TestSizes()) { + int right = x_dim_1; + auto ref = jit::GetRefer(); + EXPECT_TRUE(ref != nullptr); + int sz = left * right; + std::vector x(sz), mean(left), var(left), scale(right), bias(right), + outref(sz); + RandomVec(sz, x.data()); + RandomVec(left, mean.data()); + RandomVec(left, var.data()); + RandomVec(right, scale.data()); + RandomVec(right, bias.data()); + 
+ const T* scale_data = scale.data(); + const T* bias_data = bias.data(); + T* x_data = x.data(); + T* mean_data = mean.data(); + T* var_data = var.data(); + T* outref_data = outref.data(); + + ref(x_data, outref_data, mean_data, var_data, scale_data, bias_data, + left, epsilon, right); + + auto verifier = []( + const typename KernelTuple::func_type tgt, const std::vector& x_, + const std::vector& outref_, const std::vector& mean_, + const std::vector& var_, const std::vector& scale, + const std::vector& bias, const int& left, const float& epsilon, + const typename KernelTuple::attr_type& right) { + EXPECT_TRUE(tgt != nullptr); + std::vector outtgt(outref_.size()); + std::vector x(x_.size()); + std::vector mean(mean_.size()); + std::vector var(var_.size()); + std::vector outref(outref_.size()); + std::copy(x_.begin(), x_.end(), x.begin()); + std::copy(mean_.begin(), mean_.end(), mean.begin()); + std::copy(var_.begin(), var_.end(), var.begin()); + std::copy(outref_.begin(), outref_.end(), outref.begin()); + + EXPECT_EQ(x.size(), static_cast(left * right)); + EXPECT_EQ(outref.size(), static_cast(left * right)); + EXPECT_EQ(mean.size(), static_cast(left)); + EXPECT_EQ(var.size(), static_cast(left)); + EXPECT_EQ(scale.size(), static_cast(right)); + EXPECT_EQ(bias.size(), static_cast(right)); + + const T* scale_data = scale.data(); + const T* bias_data = bias.data(); + T* x_data = x.data(); + T* mean_data = mean.data(); + T* var_data = var.data(); + T* outref_data = outref.data(); + T* outtgt_data = outtgt.data(); + tgt(x_data, outtgt_data, mean_data, var_data, scale_data, bias_data, + left, epsilon, right); + ExpectEQ(outtgt_data, outref_data, left * right); + }; + TestAllImpls(right, verifier, x, outref, mean, + var, scale, bias, left, epsilon, + right); + } + } + } +} + +template +void TestKernelCRFDecoding() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + constexpr int state_trans_base_idx = 2; + auto test_sizes = TestSizes(); + test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000)); + for (int seq_len : {1, 11, 17, 50}) { + for (int tag_num : test_sizes) { + auto ref = jit::GetRefer(); + EXPECT_TRUE(ref != nullptr); + int x_sz = seq_len * tag_num; + int w_sz = (tag_num + state_trans_base_idx) * tag_num; + std::vector x(x_sz), w(w_sz), alpharef(x_sz); + std::vector trackref(x_sz); + RandomVec(x_sz, x.data()); + RandomVec(w_sz, w.data()); + + ref(seq_len, (const T*)x.data(), (const T*)w.data(), alpharef.data(), + trackref.data(), tag_num); + + auto verifier = []( + const typename KernelTuple::func_type tgt, const int& seq_len, + const std::vector& x, const std::vector& w, + const std::vector& alpharef, const std::vector& trackref, + const typename KernelTuple::attr_type& tag_num) { + constexpr int state_trans_base_idx = 2; + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(x.size(), static_cast(seq_len * tag_num)); + EXPECT_EQ(w.size(), static_cast( + (tag_num + state_trans_base_idx) * tag_num)); + EXPECT_EQ(alpharef.size(), static_cast(seq_len * tag_num)); + EXPECT_EQ(trackref.size(), static_cast(seq_len * tag_num)); + std::vector alphatgt(alpharef.size()); + std::vector tracktgt(trackref.size()); + memcpy(tracktgt.data(), trackref.data(), tag_num * sizeof(int)); + tgt(seq_len, (const T*)x.data(), (const T*)w.data(), alphatgt.data(), + tracktgt.data(), tag_num); + ExpectEQ(alpharef.data(), alphatgt.data(), seq_len * tag_num); + ExpectEQ(trackref.data(), tracktgt.data(), seq_len * tag_num); + }; + 
TestAllImpls(tag_num, verifier, seq_len, x, w, + alpharef, trackref, tag_num); + } + } +} + +template +void TestKernelSeqPool() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); std::vector pool_types = { jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; auto test_sizes = TestSizes(); @@ -668,7 +606,7 @@ void TestKernelSeqPoolTuples() { jit::seq_pool_attr_t attr(w, type); for (int h : test_sizes) { attr.h = h; - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); std::vector x(h * w), yref(w); RandomVec(h * w, x.data()); @@ -676,16 +614,86 @@ void TestKernelSeqPoolTuples() { T* yref_data = yref.data(); ref(x_data, yref_data, &attr); VLOG(10) << attr; - TestAllImpls, PlaceType, std::vector, - std::vector>(attr, x, yref, attr); + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& x, const std::vector& yref, + const typename KernelTuple::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(x.size() % yref.size(), static_cast(0)); + int w = yref.size(); + std::vector y(w); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + T* y_data = y.data(); + tgt(x_data, y_data, &attr); + ExpectEQ(y_data, yref_data, w); + }; + TestAllImpls(attr, verifier, x, yref, attr); + } + } + } +} + +template +void TestKernelEmbSeqPool() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + int64_t tbl_h = 1e4; + std::vector pool_types = { + jit::SeqPoolType::kSum}; // only support sum yet + auto test_sizes = TestSizes(); + test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); + for (int tbl_w : test_sizes) { + std::vector table(tbl_h * tbl_w); + RandomVec(tbl_h * tbl_w, table.data()); + const T* table_data = table.data(); + for (auto type : pool_types) { + for (int idx_w : {1, 2, 10, 16}) { + for (int idx_h : {1, 2, 9, 13, 16}) { + auto ref = jit::GetRefer(); + EXPECT_TRUE(ref != nullptr); + std::vector idx(idx_h * idx_w); + RandomVec(idx_h * idx_w, idx.data(), 0, tbl_h - 1); + int64_t out_w = tbl_w * idx_w; + std::vector oref(out_w); + const int64_t* idx_data = idx.data(); + T* o_data = oref.data(); + jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w, + type); + ref(table_data, idx_data, o_data, &attr); + + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& table, + const std::vector& idx, + const std::vector& oref, + const typename KernelTuple::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(table.size(), static_cast(attr.table_height * + attr.table_width)); + EXPECT_EQ(idx.size(), static_cast(attr.index_height * + attr.index_width)); + EXPECT_EQ(oref.size(), + static_cast(attr.table_width * attr.index_width)); + const T* table_data = table.data(); + const int64_t* idx_data = idx.data(); + const T* oref_data = oref.data(); + int o_w = oref.size(); + std::vector out(o_w); + T* o_data = out.data(); + tgt(table_data, idx_data, o_data, &attr); + ExpectEQ(o_data, oref_data, o_w); + }; + TestAllImpls(attr, verifier, table, idx, oref, + attr); + } } } } } -template -void TestKernelMatMulTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelMatMul() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); auto last_acc = FLAGS_acc; // export MKL_CBWR=AVX would make MKL 
force to use AVX // export KMP_DETERMINISTIC_REDUCTION=yes would make the result deterministic @@ -693,7 +701,7 @@ void TestKernelMatMulTuples() { for (int m : {1, 2, 3, 4}) { for (int n : {1, 2, 3, 4}) { for (int k : TestSizes()) { - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); std::vector a(m * k), b(k * n), c(m * n); RandomVec(m * k, a.data()); @@ -703,20 +711,36 @@ void TestKernelMatMulTuples() { T* c_data = c.data(); const jit::matmul_attr_t attr{m, n, k}; ref(a_data, b_data, c_data, &attr); - TestAllImpls, PlaceType, std::vector, - std::vector, std::vector>(attr, a, b, c, attr); + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& a, const std::vector& b, + const std::vector& cref, + const typename KernelTuple::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(a.size(), static_cast(attr.m * attr.k)); + EXPECT_EQ(b.size(), static_cast(attr.k * attr.n)); + EXPECT_EQ(cref.size(), static_cast(attr.m * attr.n)); + std::vector c(cref.size()); + const T* a_data = a.data(); + const T* b_data = b.data(); + const T* cref_data = cref.data(); + T* c_data = c.data(); + tgt(a_data, b_data, c_data, &attr); + ExpectEQ(c_data, cref_data, attr.m * attr.n); + }; + TestAllImpls(attr, verifier, a, b, c, attr); } } } FLAGS_acc = last_acc; } -template -void TestKernelSoftmaxTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelSoftmax() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); for (int bs : {1, 2, 10}) { for (int n : TestSizes()) { - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); std::vector x(bs * n), y(bs * n); RandomVec(bs * n, x.data()); @@ -730,51 +754,33 @@ void TestKernelSoftmaxTuples() { ref(xinp_data, xinp_data, n, bs); ExpectEQ(xinp_data, y_data, n * bs); - TestAllImpls, PlaceType, std::vector, - std::vector>(n, x, y, n, bs); - } - } -} - -template -void TestKernelEmbSeqPoolTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); - int64_t tbl_h = 1e4; - std::vector pool_types = { - jit::SeqPoolType::kSum}; // only support sum yet - auto test_sizes = TestSizes(); - test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); - for (int tbl_w : test_sizes) { - std::vector table(tbl_h * tbl_w); - RandomVec(tbl_h * tbl_w, table.data()); - const T* table_data = table.data(); - for (auto type : pool_types) { - for (int idx_w : {1, 2, 10, 16}) { - for (int idx_h : {1, 2, 9, 13, 16}) { - auto ref = jit::GetRefer>(); - EXPECT_TRUE(ref != nullptr); - std::vector idx(idx_h * idx_w); - RandomVec(idx_h * idx_w, idx.data(), 0, tbl_h - 1); - int64_t out_w = tbl_w * idx_w; - std::vector oref(out_w); - const int64_t* idx_data = idx.data(); - T* o_data = oref.data(); - jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w, - type); - ref(table_data, idx_data, o_data, &attr); - - TestAllImpls, PlaceType, std::vector, - std::vector, std::vector>(attr, table, idx, - oref, attr); - } - } + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& x, const std::vector& yref, + int n, int bs) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(yref.size(), x.size()); + EXPECT_EQ(x.size(), static_cast(n * bs)); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + std::vector ytgt(n * bs); + T* ytgt_data = ytgt.data(); + // test normal + tgt(x_data, ytgt_data, n, bs); + ExpectEQ(ytgt_data, 
yref_data, n * bs); + // test inplace x + std::copy(x.begin(), x.end(), ytgt.begin()); + tgt(ytgt_data, ytgt_data, n, bs); + ExpectEQ(ytgt_data, yref_data, n * bs); + }; + TestAllImpls(n, verifier, x, y, n, bs); } } } -template -void TestKernelSgdTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelSgd() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); const T lr = 0.1; auto UnDuplicatedRandomVec = [](int n, const int64_t lower, const int64_t upper) -> std::vector { @@ -802,7 +808,7 @@ void TestKernelSgdTuples() { RandomVec(rows_size * grad_w, grad.data()); const int64_t* rows_data = rows.data(); const T* grad_data = grad.data(); - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size); ref(&lr, param_data, grad_data, rows_data, out_data, &attr); @@ -818,199 +824,150 @@ void TestKernelSgdTuples() { grad_w); } - TestAllImpls, PlaceType, T, std::vector, - std::vector, std::vector, std::vector>( - attr, lr, param, grad, rows, param_out, attr); - } - } - } -} - -template -void TestKernelNCHW16CMulNCTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); - const int n = 3, c = 16 * 4, h = 10, w = 10; - auto ref = jit::GetRefer>(); - EXPECT_TRUE(ref != nullptr); - int sz = n * c * h * w; - std::vector x(sz), y(n * c), zref(sz); - std::vector ztgt(sz), zjit(sz); - RandomVec(sz, x.data()); - RandomVec(n * c, y.data()); - - const T* x_data = x.data(); - const T* y_data = y.data(); - T* zref_data = zref.data(); - T* ztgt_data = ztgt.data(); - T* zjit_data = zjit.data(); - constexpr int simd_width = ZMM_FLOAT_BLOCK; - int C = c / simd_width; - auto tgt = - jit::KernelFuncs, PlaceType>::Cache().At( - 0); - auto jitcode = jit::GetJitCode, PlaceType>(0); - EXPECT_TRUE(tgt != nullptr); - - if (std::is_same::value && - paddle::platform::MayIUse(paddle::platform::avx512f)) { - EXPECT_TRUE(jitcode != nullptr); - } - for (int ni = 0; ni < n; ni++) { - for (int ci = 0; ci < C; ci++) { - auto ptr_x = - x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; - auto ptr_zref = - zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - auto ptr_ztgt = - ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - - ref(ptr_x, ptr_y, ptr_zref, h, w); - tgt(ptr_x, ptr_y, ptr_ztgt, h, w); - - if (jitcode) { - auto ptr_zjit = - zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - jitcode(ptr_x, ptr_y, ptr_zjit, h, w); - } - } - } - ExpectEQ(ztgt_data, zref_data, sz); - if (jitcode) { - ExpectEQ(zjit_data, zref_data, sz); - } -} - -template -void TestKernelLayerNormTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); - const T epsilon = 9.99999975e-06; - for (int n : {1, 2, 10}) { - for (int x_dim_0 : {1, 9, 17, 50}) { - int left = n * x_dim_0; - for (int x_dim_1 : TestSizes()) { - int right = x_dim_1; - auto ref = jit::GetRefer>(); - EXPECT_TRUE(ref != nullptr); - int sz = left * right; - std::vector x(sz), mean(left), var(left), scale(right), bias(right), - outref(sz); - RandomVec(sz, x.data()); - RandomVec(left, mean.data()); - RandomVec(left, var.data()); - RandomVec(right, scale.data()); - RandomVec(right, bias.data()); - - const T* scale_data = scale.data(); - const T* bias_data = bias.data(); - T* x_data = x.data(); - T* mean_data = mean.data(); - T* 
var_data = var.data(); - T* outref_data = outref.data(); - - ref(x_data, outref_data, mean_data, var_data, scale_data, bias_data, - left, epsilon, right); + auto verifier = []( + const typename KernelTuple::func_type tgt, const T lr, + const std::vector& param, const std::vector& grad, + const std::vector& rows, const std::vector& oref, + const typename KernelTuple::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(param.size(), + static_cast(attr.param_height * attr.param_width)); + EXPECT_EQ(grad.size(), + static_cast(attr.grad_height * attr.grad_width)); + EXPECT_EQ(rows.size(), static_cast(attr.selected_rows_size)); + EXPECT_EQ(param.size(), oref.size()); + const T* param_data = param.data(); + const T* grad_data = grad.data(); + const int64_t* rows_data = rows.data(); + const T* oref_data = oref.data(); + + std::vector out(oref.size()); + T* o_data = out.data(); + tgt(&lr, param_data, grad_data, rows_data, o_data, &attr); + // only the selected rows should be equal + for (size_t i = 0; i < rows.size(); ++i) { + ExpectEQ(o_data + rows[i] * attr.grad_width, + oref_data + rows[i] * attr.grad_width, attr.grad_width); + } - TestAllImpls, PlaceType, std::vector, - std::vector, std::vector, std::vector, - std::vector, std::vector, int, float>( - right, x, outref, mean, var, scale, bias, left, epsilon, right); + // inplace + std::copy(param.begin(), param.end(), out.begin()); + tgt(&lr, o_data, grad_data, rows_data, o_data, &attr); + for (size_t i = 0; i < rows.size(); ++i) { + ExpectEQ(o_data + rows[i] * attr.grad_width, + oref_data + rows[i] * attr.grad_width, attr.grad_width); + } + }; + TestAllImpls(attr, verifier, lr, param, grad, + rows, param_out, attr); } } } } -template -void TestKernelCRFDecodingTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); - constexpr int state_trans_base_idx = 2; - auto test_sizes = TestSizes(); - test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000)); - for (int seq_len : {1, 11, 17, 50}) { - for (int tag_num : test_sizes) { - auto ref = jit::GetRefer>(); - EXPECT_TRUE(ref != nullptr); - int x_sz = seq_len * tag_num; - int w_sz = (tag_num + state_trans_base_idx) * tag_num; - std::vector x(x_sz), w(w_sz), alpharef(x_sz); - std::vector trackref(x_sz); - RandomVec(x_sz, x.data()); - RandomVec(w_sz, w.data()); - - ref(seq_len, (const T*)x.data(), (const T*)w.data(), alpharef.data(), - trackref.data(), tag_num); - - TestAllImpls, PlaceType, int, - std::vector, std::vector, std::vector, - std::vector, int>(tag_num, seq_len, x, w, alpharef, - trackref, tag_num); - } - } -} - -template -void TestKernelVBroadcastTuples() { - VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); +template +void TestKernelVBroadcast() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); for (int w : TestSizes()) { std::vector x(w); RandomVec(w, x.data()); const T* x_data = x.data(); for (int64_t h : {1, 2, 6}) { - auto ref = jit::GetRefer>(); + auto ref = jit::GetRefer(); EXPECT_TRUE(ref != nullptr); std::vector y(w * h); T* y_data = y.data(); ref(x_data, y_data, h, w); - TestAllImpls, PlaceType, std::vector, - std::vector, int64_t>(static_cast(w), x, y, h, - static_cast(w)); + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& x, const std::vector& yref, + const int64_t& h, + const typename KernelTuple::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(x.size(), static_cast(attr)); + EXPECT_EQ(yref.size(), x.size() 
* h); + std::vector y(yref.size()); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + T* y_data = y.data(); + tgt(x_data, y_data, h, attr); + ExpectEQ(y_data, yref_data, yref.size()); + }; + TestAllImpls(static_cast(w), verifier, x, + y, h, static_cast(w)); } } } -#define TEST_CPU_KERNEL(test_tuple, kernel_type) \ - TEST(JITKernel, kernel_type) { \ - TestKernel##test_tuple(); \ - TestKernel##test_tuple(); \ +#define TestKernelVMul TestKernelXYZN +#define TestKernelVAdd TestKernelXYZN +#define TestKernelVAddRelu TestKernelXYZN +#define TestKernelVSub TestKernelXYZN + +#define TestKernelVScal TestKernelAXYN +#define TestKernelVAddBias TestKernelAXYN + +#define TestKernelVRelu TestKernelXYN +#define TestKernelVIdentity TestKernelXYN +#define TestKernelVSquare TestKernelXYN +#define TestKernelVExp TestKernelXYN +#define TestKernelVSigmoid TestKernelXYN +#define TestKernelVTanh TestKernelXYN +#define TestKernelVCopy TestKernelXYN + +#define TestKernelHMax TestKernelXRN +#define TestKernelHSum TestKernelXRN + +#define TestKernelLSTMCtHt TestKernelLSTM +#define TestKernelLSTMC1H1 TestKernelLSTM + +#define TestKernelGRUH1 TestKernelGRU +#define TestKernelGRUHtPart1 TestKernelGRU +#define TestKernelGRUHtPart2 TestKernelGRU + +#define TEST_CPU_KERNEL(kernel_type) \ + TEST(JITKernel, kernel_type) { \ + TestKernel##kernel_type, CPUPlace>(); \ + TestKernel##kernel_type, CPUPlace>(); \ } -TEST_CPU_KERNEL(XYZNTuples, kVMul); -TEST_CPU_KERNEL(XYZNTuples, kVAdd); -TEST_CPU_KERNEL(XYZNTuples, kVAddRelu); -TEST_CPU_KERNEL(XYZNTuples, kVSub); +TEST_CPU_KERNEL(VMul); +TEST_CPU_KERNEL(VAdd); +TEST_CPU_KERNEL(VAddRelu); +TEST_CPU_KERNEL(VSub); -TEST_CPU_KERNEL(AXYNTuples, kVScal); -TEST_CPU_KERNEL(AXYNTuples, kVAddBias); +TEST_CPU_KERNEL(VScal); +TEST_CPU_KERNEL(VAddBias); -TEST_CPU_KERNEL(XRNTuples, kHMax); -TEST_CPU_KERNEL(XRNTuples, kHSum); +TEST_CPU_KERNEL(VRelu); +TEST_CPU_KERNEL(VIdentity); +TEST_CPU_KERNEL(VSquare); +TEST_CPU_KERNEL(VExp); +TEST_CPU_KERNEL(VSigmoid); +TEST_CPU_KERNEL(VTanh); +TEST_CPU_KERNEL(VCopy); -TEST_CPU_KERNEL(XYNTuples, kVRelu); -TEST_CPU_KERNEL(XYNTuples, kVIdentity); -TEST_CPU_KERNEL(XYNTuples, kVSquare); -TEST_CPU_KERNEL(XYNTuples, kVExp); -TEST_CPU_KERNEL(XYNTuples, kVSigmoid); -TEST_CPU_KERNEL(XYNTuples, kVTanh); -TEST_CPU_KERNEL(XYNTuples, kVCopy); +TEST_CPU_KERNEL(HMax); +TEST_CPU_KERNEL(HSum); -TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt); -TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1); +TEST_CPU_KERNEL(LSTMCtHt); +TEST_CPU_KERNEL(LSTMC1H1); -TEST_CPU_KERNEL(GRUTuples, kGRUH1); -TEST_CPU_KERNEL(GRUTuples, kGRUHtPart1); -TEST_CPU_KERNEL(GRUTuples, kGRUHtPart2); +TEST_CPU_KERNEL(GRUH1); +TEST_CPU_KERNEL(GRUHtPart1); +TEST_CPU_KERNEL(GRUHtPart2); -TEST_CPU_KERNEL(NCHW16CMulNCTuples, kNCHW16CMulNC); +TEST_CPU_KERNEL(NCHW16CMulNC); +TEST_CPU_KERNEL(LayerNorm); +TEST_CPU_KERNEL(CRFDecoding); -TEST_CPU_KERNEL(SeqPoolTuples, kSeqPool); -TEST_CPU_KERNEL(MatMulTuples, kMatMul); -TEST_CPU_KERNEL(SoftmaxTuples, kSoftmax); -TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool); -TEST_CPU_KERNEL(SgdTuples, kSgd); -TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm); -TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding); -TEST_CPU_KERNEL(VBroadcastTuples, kVBroadcast); +TEST_CPU_KERNEL(SeqPool); +TEST_CPU_KERNEL(EmbSeqPool); +TEST_CPU_KERNEL(MatMul); +TEST_CPU_KERNEL(Softmax); +TEST_CPU_KERNEL(Sgd); +TEST_CPU_KERNEL(VBroadcast); TEST(JITKernel_key, lstm) { jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); @@ -1045,16 +1002,9 @@ TEST(JITKernel_key, gru) { } TEST(JITKernel, 
kernel_func) { - auto f1 = - jit::KernelFuncs, CPUPlace>::Cache() - .At(3); - auto f2 = jit::KernelFuncs, - CPUPlace>::Cache()[3]; + auto f1 = jit::KernelFuncs, CPUPlace>::Cache().At(3); + auto f2 = jit::KernelFuncs, CPUPlace>::Cache()[3]; + EXPECT_TRUE(f1 != nullptr); EXPECT_TRUE(f1 == f2); - - f1 = jit::KernelFuncs, CPUPlace>::Cache() - .At(3); - f2 = jit::KernelFuncs, CPUPlace>::Cache() - .At(4); - EXPECT_TRUE(f1 != f2); + // TODO(TJ): check not equal } diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index f0c3064d41..8627c83b43 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -229,9 +229,9 @@ class LayerNormKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(scale->numel(), right); PADDLE_ENFORCE_EQ(bias->numel(), right); - auto ker = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(right); + auto ker = + jit::KernelFuncs, platform::CPUPlace>::Cache() + .At(right); ker(x.data(), out.data(), mean->data(), var->data(), scale->data(), bias->data(), static_cast(left), static_cast(epsilon), right); diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h index 0ad57c51be..66ce57594a 100644 --- a/paddle/fluid/operators/math/fc_compute.h +++ b/paddle/fluid/operators/math/fc_compute.h @@ -30,17 +30,16 @@ inline void FCCompute(const BlasT& blas, const int M, return; } if (relu) { - auto compute = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(N); + auto compute = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + N); for (int i = 0; i < M; i++) { T* dst = Y + i * N; compute(B, dst, dst, N); } } else { - auto compute = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(N); + auto compute = + jit::KernelFuncs, platform::CPUPlace>::Cache().At(N); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index db103e5fab..7af44f2b2c 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -255,9 +255,9 @@ class SequencePoolFunctor { jit::seq_pool_attr_t attr( static_cast(input.numel() / input.dims()[0]), jit::SeqPoolType::kSum); - auto seqpool = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr); + auto seqpool = + jit::KernelFuncs, platform::CPUPlace>::Cache() + .At(attr); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { attr.h = static_cast(lod[i + 1] - lod[i]); seqpool(src, dst, &attr); diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index a1cb3f9728..d77b6712c5 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -82,8 +82,7 @@ class SoftmaxFunctor> { const int kClassDim = 1; // 2D data. 
Batch x C auto compute_softmax = - jit::KernelFuncs, - platform::CPUPlace>::Cache() + jit::KernelFuncs, platform::CPUPlace>::Cache() .At(in_dims[kClassDim]); compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]); } diff --git a/paddle/fluid/operators/optimizers/sgd_op.h b/paddle/fluid/operators/optimizers/sgd_op.h index 0425a3d194..5dd5f67e00 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.h +++ b/paddle/fluid/operators/optimizers/sgd_op.h @@ -47,9 +47,9 @@ class SGDOpKernel : public framework::OpKernel { int64_t rows_idx = 0; T *out_data = param_out->mutable_data(ctx.GetPlace()); - auto sgd = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr); + auto sgd = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr); sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); } else if (grad_var->IsType()) { // TODO(qijun): In Sparse SGD operator, in-place update is enforced. @@ -82,9 +82,9 @@ class SGDOpKernel : public framework::OpKernel { attr.selected_rows_size = grad_rows.size(); PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width); - auto sgd = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr); + auto sgd = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr); sgd(lr, param_data, grad_data, rows_data, out_data, &attr); } else { PADDLE_THROW("Unsupported Variable Type of Grad"); -- GitLab
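
A minimal usage sketch of the call convention this patch converges on: every call site now obtains a kernel function from `jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr)`, where the kernel tuple type itself carries the kernel enum and the data type, so the separate `jit::KernelType` template argument and the old `*Tuples<T>` helper disappear. The concrete tuple name used below (`jit::VAddTuple<float>`) and the header path are assumptions inferred from the renamed declarations in refer.h and the op call sites; the angle-bracket arguments were lost in this copy of the patch, so treat this as a sketch rather than the patch's literal code.

// Sketch only: assumes jit::VAddTuple<float> is one of the per-kernel tuple
// types introduced by this patch (kernel_type + data_type bundled together)
// and that the public header is paddle/fluid/operators/jit/kernels.h.
#include "paddle/fluid/operators/jit/kernels.h"

void vadd_example(const float* x, const float* y, float* z, int n) {
  namespace jit = paddle::operators::jit;
  using CPUPlace = paddle::platform::CPUPlace;
  // One template parameter for the kernel tuple, one for the place; the
  // attribute (here just the vector length) selects or generates the cached
  // implementation (jitcode, "more", or reference), mirroring TestAllImpls.
  auto vadd = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(n);
  vadd(x, y, z, n);
}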