diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index a04c1c1263fba659e2d3f623b607e9f476bb40ed..120b2ab440156f6020fd6005dd64a48e9a6918ec 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -16,10 +16,9 @@ limitations under the License. */ #include // for memcpy #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/sequence2batch.h" -#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel { } } -#define INIT_VEC_FUNC \ - std::function act_gate, act_state; \ - std::function cross; \ - auto& act_gate_str = ctx.Attr("gate_activation"); \ - auto& act_state_str = ctx.Attr("activation"); \ - if (platform::jit::MayIUse(platform::jit::avx)) { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_state = act_functor(act_state_str); \ - cross = math::vec_cross; \ - } else { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_state = act_functor(act_state_str); \ - cross = math::vec_cross; \ - } - -#define INIT_BASE_INPUT_OUTPUT \ - auto* h0 = ctx.Input("H0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - bool is_reverse = ctx.Attr("is_reverse"); - -#define INIT_BASE_SIZES \ - auto x_dims = x->dims(); /* T x M*/ \ - auto wh_dims = wh->dims(); /* D x 3D*/ \ - const int total_T = x_dims[0]; \ - const int M = x_dims[1]; \ - const int D = wh_dims[0]; \ - const int D3 = wh_dims[1]; \ - const int D2 = D * 2; +#define INIT_BASE_DEFINES \ + auto* x = ctx.Input("X"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* xx = ctx.Output("XX"); \ + auto x_lod = x->lod(); \ + auto x_dims = x->dims(); /* T x M*/ \ + auto wh_dims = wh->dims(); /* D x 3D*/ \ + const int total_T = x_dims[0]; \ + const int D3 = wh_dims[1] + +#define INIT_OTHER_DEFINES \ + auto* h0 = ctx.Input("H0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* bias = ctx.Input("Bias"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D2 = D * 2; \ + const auto& ker = math::jitkernel::KernelPool::Instance() \ + .template Get, \ + const std::string&, const std::string&>( \ + ctx.Attr("gate_activation"), \ + ctx.Attr("activation"), D); \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + auto place = ctx.GetPlace(); \ + T* xx_data = xx->mutable_data(place) void SeqCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; - auto* x = ctx.Input("X"); - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES - INIT_VEC_FUNC - - auto x_lod = x->lod(); + INIT_BASE_DEFINES; + INIT_OTHER_DEFINES; const int N = x_lod[0].size() - 1; - const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : nullptr; - const T* wx_data = wx->data(); - const T* wh_data = wh->data(); const T* wh_state_data = wh_data + D * D2; - T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); - + T* hidden_out_data = hidden_out->mutable_data(place); auto blas = math::GetBlas(ctx); math::FCCompute(blas, total_T, D3, M, x_data, wx_data, xx_data, @@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel { if (h0_data) { prev_hidden_data = h0_data + bid * D; } else { - // W: {W_update, W_reset; W_state} - // update gate - act_gate(D, xx_data, xx_data); - // state gate - act_state(D, xx_data + D2, xx_data + D2); - // out = a*b - blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data); - // save prev + ker->ComputeH1(xx_data, hidden_out_data); prev_hidden_data = hidden_out_data; tstart = 1; move_step(); @@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel { blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast(1), prev_hidden_data, D, wh_data, D2, static_cast(1), xx_data, D3); - act_gate(D2, xx_data, xx_data); - // rt = rt*ht_1 inplace result - blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data); - + ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data); // gemm rt * Ws blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast(1), hidden_out_data, D, wh_state_data, D, static_cast(1), xx_data + D2, D3); - act_state(D, xx_data + D2, xx_data + D2); - // out = zt*ht~ + (1-zt)*ht_1 - cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data); + ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data); // save prev prev_hidden_data = hidden_out_data; move_step(); @@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; - auto* x = ctx.Input("X"); - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES - if (x->lod()[0].size() == 2) { + INIT_BASE_DEFINES; + if (x_lod[0].size() == 2) { xx->Resize({total_T, D3}); SeqCompute(ctx); return; } - INIT_VEC_FUNC - + INIT_OTHER_DEFINES; auto* reordered_h0 = ctx.Output("ReorderedH0"); auto* batched_input = ctx.Output("BatchedInput"); auto* batched_out = ctx.Output("BatchedOut"); - - const T* x_data = x->data(); - const T* wx_data = wx->data(); - const T* wh_data = wh->data(); - T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* batched_input_data = batched_input->mutable_data(ctx.GetPlace()); - T* batched_out_data = batched_out->mutable_data(ctx.GetPlace()); - hidden_out->mutable_data(ctx.GetPlace()); - + T* batched_input_data = batched_input->mutable_data(place); + T* batched_out_data = batched_out->mutable_data(place); + hidden_out->mutable_data(place); auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); math::LoDTensor2BatchFunctor to_batch; @@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel { T* prev_hidden_data = nullptr; if (h0) { // reorder h0 - T* reordered_h0_data = reordered_h0->mutable_data(ctx.GetPlace()); + T* reordered_h0_data = reordered_h0->mutable_data(place); const T* h0_data = h0->data(); prev_hidden_data = reordered_h0_data; size_t sz = sizeof(T) * D; @@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel { T* cur_out_data = batched_out_data; // W: {W_update, W_reset; W_state} for (int i = 0; i < max_bs; ++i) { - // update gate - act_gate(D, cur_in_data, cur_in_data); - // state gate - act_state(D, cur_in_data + D2, cur_in_data + D2); - // out = a*b - blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data); + ker->ComputeH1(cur_in_data, cur_out_data); // add offset cur_in_data += D3; cur_out_data += D; @@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel { T* cur_out_data = batched_out_data; T* cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { - act_gate(D2, cur_batched_data, cur_batched_data); - // rt = rt*ht_1 inplace result - blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data); - + ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data, + cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; cur_out_data += D; @@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel { cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { - // ht~ = act_state(...) - act_state(D, cur_batched_data + D2, cur_batched_data + D2); - // out = zt*ht~ + (1-zt)*ht_1 - cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data, - cur_out_data); - + ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data, + cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; cur_out_data += D; @@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel { batched_out->set_lod(batched_lod); to_seq(dev_ctx, *batched_out, hidden_out); } -#undef INIT_VEC_FUNC -#undef INIT_BASE_SIZES -#undef INIT_BASE_INPUT_OUTPUT +#undef INIT_OTHER_DEFINES +#undef INIT_BASE_DEFINES }; } // namespace operators diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index c7bdec354735773a15b4c99baf9f7798f2d92564..651fd4dfa63f611d5766992d3f921242c8881a13 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -75,6 +75,6 @@ endif() cc_test(concat_test SRCS concat_test.cc DEPS concat) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_library(jit_kernel - SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc + SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc DEPS cpu_info cblas) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index e91e4e8e5adfdfff8163efe7fc1451bc602504e0..9088d0c7a6307c3fbd9707c719ec9e6f6c85fbdb 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -142,6 +142,15 @@ class LSTMKernel : public Kernel { const T *wp_data = nullptr) const = 0; }; +template +class GRUKernel : public Kernel { + public: + // compute h1 without h0 + virtual void ComputeH1(T *gates, T *ht) const = 0; + virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0; + virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0; +}; + } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc similarity index 87% rename from paddle/fluid/operators/math/jit_kernel_lstm.cc rename to paddle/fluid/operators/math/jit_kernel_rnn.cc index 26bd26e2e171feea569fbd646a9caf03bebbaa46..a340cdfaf6cd9009a41b2a9a4905a580c3785c3a 100644 --- a/paddle/fluid/operators/math/jit_kernel_lstm.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -354,6 +354,67 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, #undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_KEY_LSTM #undef JITKERNEL_NEW_LSTM_IMPL + +/* GRU JitKernel */ +template +class GRUKernelImpl : public GRUKernel { + public: + explicit GRUKernelImpl(const std::string& act_gate, + const std::string& act_state, int d) + : GRUKernel() { + d_ = d; + d2_ = d * 2; + act_gate_d2_ = GetActKernel(act_gate, d2_); + act_gate_d_ = GetActKernel(act_gate, d); + act_state_d_ = GetActKernel(act_state, d); + vmul_d_ = KernelPool::Instance().template Get>(d); + } + + void ComputeH1(T* gates, T* ht) const override { + act_gate_d_->Compute(gates, gates); + act_state_d_->Compute(gates + d2_, gates + d2_); + vmul_d_->Compute(gates, gates + d2_, ht); + } + void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { + // W: {W_update, W_reset; W_state} + act_gate_d2_->Compute(gates, gates); + vmul_d_->Compute(ht_1, gates + d_, ht); + } + + void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { + T* y = gates + d2_; + act_state_d_->Compute(y, y); + // out = zt*ht~ + (1-zt)*ht_1 + for (int i = 0; i < d_; ++i) { + ht[i] = gates[i] * y[i] + (static_cast(1) - gates[i]) * ht_1[i]; + } + } + + private: + int d_, d2_; + std::shared_ptr> act_gate_d2_, act_gate_d_, act_state_d_; + std::shared_ptr> vmul_d_; +}; + +#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \ + template <> \ + std::shared_ptr> KernelPool::Get< \ + GRUKernel, const std::string&, const std::string&, int>( \ + const std::string& act_gate, const std::string& act_state, int d) + +#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \ + #ker_key #dtype_key + std::to_string(d) + act_gate + act_state + +#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(act_gate, act_state, d)); + +REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU, + JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL); + +#undef JITKERNEL_NEW_GRU_IMPL +#undef JITKERNEL_KEY_GRU +#undef JITKERNEL_DECLARE_GRU } // namespace jitkernel } // namespace math } // namespace operators