From 64a90b2f1c762dc4a093da413b9c945c99b82e73 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Mon, 17 Dec 2018 12:29:22 +0000
Subject: [PATCH] use vadd, vaddrelu, lstm and gru jitkernel

---
 paddle/fluid/operators/fused/fusion_gru_op.cc  | 58 ++++++++---------
 .../fluid/operators/fused/fusion_lstm_op.cc    | 62 ++++++++++---------
 paddle/fluid/operators/math/CMakeLists.txt     |  9 ---
 paddle/fluid/operators/math/fc_compute.h       | 14 ++---
 4 files changed, 68 insertions(+), 75 deletions(-)

diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index 25b7ae7c2..d44a7ad83 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -15,9 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fusion_gru_op.h"
 #include <cstring>  // for memcpy
 #include <string>
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 
 namespace paddle {
 namespace operators {
@@ -183,27 +183,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   const int total_T = x_dims[0];                                        \
   const int D3 = wh_dims[1]
 
-#define INIT_OTHER_DEFINES                                              \
-  auto* h0 = ctx.Input<Tensor>("H0");                                   \
-  auto* wx = ctx.Input<Tensor>("WeightX");                              \
-  auto* bias = ctx.Input<Tensor>("Bias");                               \
-  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                   \
-  bool is_reverse = ctx.Attr<bool>("is_reverse");                       \
-  const int M = x_dims[1];                                              \
-  const int D = wh_dims[0];                                             \
-  const int D2 = D * 2;                                                 \
-  const math::jitkernel::gru_attr_t attr(                               \
-      D, ctx.Attr<std::string>("gate_activation"),                      \
-      ctx.Attr<std::string>("activation"));                             \
-  math::jitkernel::gru_t one_step;                                      \
-  const auto& ker =                                                     \
-      math::jitkernel::KernelPool::Instance()                           \
-          .template Get<math::jitkernel::GRUKernel<T>,                  \
-                        const math::jitkernel::gru_attr_t&>(attr);      \
-  const T* x_data = x->data<T>();                                       \
-  const T* wx_data = wx->data<T>();                                     \
-  const T* wh_data = wh->data<T>();                                     \
-  auto place = ctx.GetPlace();                                          \
+#define INIT_OTHER_DEFINES                                                    \
+  auto* h0 = ctx.Input<Tensor>("H0");                                         \
+  auto* wx = ctx.Input<Tensor>("WeightX");                                    \
+  auto* bias = ctx.Input<Tensor>("Bias");                                     \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                         \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");                             \
+  const int M = x_dims[1];                                                    \
+  const int D = wh_dims[0];                                                   \
+  const int D2 = D * 2;                                                       \
+  const jit::gru_attr_t attr(                                                 \
+      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),        \
+      jit::to_kerneltype(ctx.Attr<std::string>("activation")));               \
+  jit::gru_t one_step;                                                        \
+  auto ComputeH1 =                                                            \
+      jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr);     \
+  auto ComputeHtPart1 =                                                       \
+      jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr);\
+  auto ComputeHtPart2 =                                                       \
+      jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr);\
+  const T* x_data = x->data<T>();                                             \
+  const T* wx_data = wx->data<T>();                                           \
+  const T* wh_data = wh->data<T>();                                           \
+  auto place = ctx.GetPlace();                                                \
   T* xx_data = xx->mutable_data<T>(place)
 
   void SeqCompute(const framework::ExecutionContext& ctx) const {
@@ -242,7 +244,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     } else {
       one_step.gates = xx_data;
       one_step.ht = hidden_out_data;
-      ker->ComputeH1(&one_step, &attr);
+      ComputeH1(&one_step, &attr);
       prev_hidden_data = hidden_out_data;
       tstart = 1;
       move_step();
@@ -255,12 +257,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = xx_data;
       one_step.ht_1 = prev_hidden_data;
       one_step.ht = hidden_out_data;
-      ker->ComputeHtPart1(&one_step, &attr);
+      ComputeHtPart1(&one_step, &attr);
       // gemm rt * Ws
       blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                 hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                 xx_data + D2, D3);
-      ker->ComputeHtPart2(&one_step, &attr);
+      ComputeHtPart2(&one_step, &attr);
       // save prev
       prev_hidden_data = hidden_out_data;
       move_step();
@@ -324,7 +326,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     for (int i = 0; i < max_bs; ++i) {
       one_step.gates = cur_in_data;
       one_step.ht = cur_out_data;
-      ker->ComputeH1(&one_step, &attr);
+      ComputeH1(&one_step, &attr);
       // add offset
       cur_in_data += D3;
       cur_out_data += D;
@@ -352,7 +354,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = cur_batched_data;
       one_step.ht_1 = cur_prev_hidden_data;
       one_step.ht = cur_out_data;
-      ker->ComputeHtPart1(&one_step, &attr);
+      ComputeHtPart1(&one_step, &attr);
 
       cur_batched_data += D3;
       cur_prev_hidden_data += D;
@@ -370,7 +372,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = cur_batched_data;
       one_step.ht_1 = cur_prev_hidden_data;
       one_step.ht = cur_out_data;
-      ker->ComputeHtPart2(&one_step, &attr);
+      ComputeHtPart2(&one_step, &attr);
       cur_batched_data += D3;
       cur_prev_hidden_data += D;
       cur_out_data += D;
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index 8021a896c..a62f4d18c 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -14,9 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fusion_lstm_op.h"
 #include <string>
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 
 namespace paddle {
 namespace operators {
@@ -236,31 +236,33 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   const int D = wh_dims[0];                                             \
   const int D4 = wh_dims[1]
 
-#define INIT_OTHER_DEFINES                                              \
-  const T* x_data = x->data<T>();                                       \
-  const T* wx_data = wx->data<T>();                                     \
-  const T* wh_data = wh->data<T>();                                     \
-  /* diagonal weight*/                                                  \
-  const T* wp_data = bias->data<T>() + D4;                              \
-  /* for peephole only*/                                                \
-  T* checked_cell_data = nullptr;                                       \
-  auto place = ctx.GetPlace();                                          \
-  if (use_peepholes) {                                                  \
-    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                    \
-    auto* checked_cell = ctx.Output<Tensor>("CheckedCell");             \
-    checked_cell_data = checked_cell->mutable_data<T>(place);           \
-  }                                                                     \
-  const math::jitkernel::lstm_attr_t attr(                              \
-      D, ctx.Attr<std::string>("gate_activation"),                      \
-      ctx.Attr<std::string>("candidate_activation"),                    \
-      ctx.Attr<std::string>("cell_activation"), use_peepholes);         \
-  math::jitkernel::lstm_t one_step;                                     \
-  one_step.wp = wp_data;                                                \
-  one_step.checked = checked_cell_data;                                 \
-  const auto& ker =                                                     \
-      math::jitkernel::KernelPool::Instance()                           \
-          .template Get<math::jitkernel::LSTMKernel<T>,                 \
-                        const math::jitkernel::lstm_attr_t&>(attr)
+#define INIT_OTHER_DEFINES                                                    \
+  const T* x_data = x->data<T>();                                             \
+  const T* wx_data = wx->data<T>();                                           \
+  const T* wh_data = wh->data<T>();                                           \
+  /* diagonal weight*/                                                        \
+  const T* wp_data = bias->data<T>() + D4;                                    \
+  /* for peephole only*/                                                      \
+  T* checked_cell_data = nullptr;                                             \
+  auto place = ctx.GetPlace();                                                \
+  if (use_peepholes) {                                                        \
+    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                          \
+    auto* checked_cell = ctx.Output<Tensor>("CheckedCell");                   \
+    checked_cell_data = checked_cell->mutable_data<T>(place);                 \
+  }                                                                           \
+  const jit::lstm_attr_t attr(                                                \
+      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),        \
+      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")),      \
+      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")),           \
+      use_peepholes);                                                         \
+  jit::lstm_t one_step;                                                       \
+  one_step.wp = wp_data;                                                      \
+  one_step.checked = checked_cell_data;                                       \
+  auto ComputeC1H1 =                                                          \
+      jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
+  auto ComputeCtHt =                                                          \
+      jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
 
 // Wh GEMM
 #define GEMM_WH_ADDON(bs, prev, out)                                    \
@@ -306,7 +308,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.gates = xx_data;
       one_step.ct = c_out_data;
       one_step.ht = h_out_data;
-      ker->ComputeC1H1(&one_step, &attr);
+      ComputeC1H1(&one_step, &attr);
       tstart = 1;
       // move one step
       prev_h_data = h_out_data;
@@ -322,7 +324,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.ct_1 = prev_c_data;
       one_step.ct = c_out_data;
       one_step.ht = h_out_data;
-      ker->ComputeCtHt(&one_step, &attr);
+      ComputeCtHt(&one_step, &attr);
       // move one step
       prev_h_data = h_out_data;
       prev_c_data = c_out_data;
@@ -402,7 +404,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.gates = cur_in_data;
       one_step.ct = cur_c_out_data;
      one_step.ht = cur_h_out_data;
-      ker->ComputeC1H1(&one_step, &attr);
+      ComputeC1H1(&one_step, &attr);
 
       cur_in_data += D4;
       cur_c_out_data += D;
@@ -432,7 +434,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.ct_1 = cur_prev_c_data;
       one_step.ct = cur_c_out_data;
       one_step.ht = cur_h_out_data;
-      ker->ComputeCtHt(&one_step, &attr);
+      ComputeCtHt(&one_step, &attr);
 
       // move one batch
       cur_in_data += D4;
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 8e8f83a63..ea6aebd29 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -73,12 +73,3 @@ if(WITH_GPU)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
-
-# set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc jit_kernel_layer_norm.cc)
-# set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
-# if(WITH_XBYAK)
-#   list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
-#   list(APPEND JIT_KERNEL_DEPS xbyak)
-# endif()
-# cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
-# cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h
index 5b9953a5a..5e3093c69 100644
--- a/paddle/fluid/operators/math/fc_compute.h
+++ b/paddle/fluid/operators/math/fc_compute.h
@@ -14,8 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
 
 namespace paddle {
 namespace operators {
@@ -30,22 +30,20 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
     return;
   }
   if (relu) {
-    const auto& vaddrelu = jitkernel::KernelPool::Instance()
-                               .template Get<jitkernel::VAddReluKernel<T>>(N);
+    auto compute =
+        jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
     for (int i = 0; i < M; i++) {
       T* dst = Y + i * N;
-      vaddrelu->Compute(B, dst, dst, N);
+      compute(B, dst, dst, N);
     }
   } else {
-    const auto& vadd = jitkernel::KernelPool::Instance()
-                           .template Get<jitkernel::VAddKernel<T>>(N);
-
+    auto compute = jit::Get<jit::kVAdd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
     for (int i = 0; i < M; i++) {
       T* dst = Y + i * N;
-      vadd->Compute(B, dst, dst, N);
+      compute(B, dst, dst, N);
     }
   }
 }
-- 
GitLab
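
Note on the call-site pattern this patch introduces: the old math::jitkernel
API fetched a kernel object from a singleton KernelPool and dispatched through
a virtual Compute* method, while jit::Get resolves the best available
implementation for the given attributes once and hands back a plain function
pointer. A minimal sketch of the new pattern, assuming the kernel/tuple names
restored above (kVAdd, XYZNTuples) and a made-up helper AddBiasRow used purely
for illustration:

    #include "paddle/fluid/operators/jit/kernels.h"

    namespace jit = paddle::operators::jit;

    // Hypothetical helper: row[i] = bias[i] + row[i] for one output row.
    void AddBiasRow(const float* bias, float* row, int n) {
      // The lookup happens once; the result is a bare function pointer, so
      // hot loops (like FCCompute above) hoist it out and pay no pool lookup
      // or virtual dispatch per row.
      auto vadd = jit::Get<jit::kVAdd, jit::XYZNTuples<float>,
                           paddle::platform::CPUPlace>(n);
      vadd(bias, row, row, n);  // z = x + y, with z aliasing y
    }

This is also why the fused GRU/LSTM kernels fetch ComputeH1, ComputeHtPart1,
ComputeCtHt, etc. once in INIT_OTHER_DEFINES and then call them directly
inside the per-time-step loops.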