diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 5c746878823b3dcde2573feec00d3d9dac5ceab8..087f903a8bba9a4bfcd7eaabd7098555442a904e 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -14,6 +14,11 @@ limitations under the License. */ #include "paddle/fluid/operators/gru_op.h" #include +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" +#include "paddle/fluid/operators/math/detail/gru_kernel.h" + +DECLARE_int32(paddle_num_threads); namespace paddle { namespace operators { @@ -211,6 +216,158 @@ class GRUGradOp : public framework::OperatorWithKernel { } }; +template +class GRUCPUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + to_batch(dev_ctx, *input, batch_gate, true, is_reverse); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + framework::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState( + context.template device_context(), *h0, order, + &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t seq_len = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); + +#ifdef PADDLE_WITH_MKLML + // use MKL packed to speedup GEMM + if (FLAGS_paddle_num_threads >= 4) { + auto blas = math::GetBlas(dev_ctx); + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2, + frame_size, gru_value.prev_out_value, frame_size, packed_gate, + frame_size * 2, T(1), gru_value.gate_value, frame_size * 3); + } + + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, + frame_size, T(1), gru_value.gate_value + frame_size * 2, + frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); + + gru_value.prev_out_value = gru_value.output_value; + } + + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); + } else { +#endif + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); + + gru_value.prev_out_value = gru_value.output_value; + } +#ifdef PADDLE_WITH_MKLML + } +#endif + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + } // namespace operators } // namespace paddle @@ -218,9 +375,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(gru_grad, ops::GRUGradOp); -REGISTER_OP_CPU_KERNEL( - gru, ops::GRUKernel, - ops::GRUKernel); +REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel, + ops::GRUCPUKernel); REGISTER_OP_CPU_KERNEL( gru_grad, ops::GRUGradKernel, ops::GRUGradKernel); diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc index baf455a840314d1ab94eb8e0a2e5c660ba4202da..55721c283dd18c2f9642563a9ce1eabfce16fd7b 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -14,6 +14,96 @@ limitations under the License. */ #include "paddle/fluid/operators/gru_op.h" +namespace paddle { +namespace operators { + +template +class GRUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + to_batch(dev_ctx, *input, batch_gate, true, is_reverse); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + framework::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState( + context.template device_context(), *h0, order, + &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); + gru_value.prev_out_value = gru_value.output_value; + } + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( gru, ops::GRUKernel, diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 3b0d93e54b72910de1429ddf41eb6b0fe9646942..0b551e8046be16c95f7d6b10b68b32a9af594f73 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -37,90 +37,6 @@ inline void ReorderInitState(const DeviceContext& ctx, row_shuffle(ctx, src, index_lod, dst, indexed_src); } -template -class GRUKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - auto* input = context.Input("Input"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* bias = context.Input("Bias"); - auto* batch_gate = context.Output("BatchGate"); - batch_gate->mutable_data(context.GetPlace()); - auto* batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - auto* batch_hidden = context.Output("BatchHidden"); - batch_hidden->mutable_data(context.GetPlace()); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - auto hidden_dims = hidden->dims(); - - bool is_reverse = context.Attr("is_reverse"); - math::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = context.template device_context(); - to_batch(dev_ctx, *input, batch_gate, true, is_reverse); - - if (bias) { - math::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, *bias, batch_gate); - } - - int frame_size = hidden_dims[1]; - math::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - Tensor ordered_h0; - - framework::Vector order(batch_gate->lod()[2]); - - if (h0) { - // Since the batch computing for GRU reorders the input sequences - // according to their length. The initialized cell state also needs - // to reorder. - ReorderInitState( - context.template device_context(), *h0, order, - &ordered_h0, true); - gru_value.prev_out_value = ordered_h0.data(); - } else { - gru_value.prev_out_value = nullptr; - } - auto batch_starts = batch_gate->lod()[0]; - size_t num_batch = batch_starts.size() - 1; - auto active_node = math::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = math::detail::GetActivationType( - context.Attr("gate_activation")); - for (size_t n = 0; n < num_batch; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); - gru_value.prev_out_value = gru_value.output_value; - } - - math::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batch_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden); - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - template class GRUGradKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 70f88f24f682e05972ca73ef7b50f96be50d1ef4..2558154e0b39a4281bfaa59ba75867589d73be5d 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -90,6 +90,25 @@ class Blas { void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; +#ifdef PADDLE_WITH_MKLML + template + T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, + const int K) const; + + template + void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M, + int N, int K, const T alpha, const T* src, const int ld, + T* dst) const; + + template + void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A, + const int lda, const T* B, const int ldb, T beta, T* C, + const int ldc) const; + + template + void GEMM_FREE(T* data) const; +#endif + template void MatMul(const framework::Tensor& mat_a, bool trans_a, const framework::Tensor& mat_b, bool trans_b, T alpha, @@ -146,6 +165,28 @@ class BlasT : private Blas { Base()->template GEMM(args...); } +#ifdef PADDLE_WITH_MKLML + template + T* GEMM_ALLOC(ARGS... args) const { + return Base()->template GEMM_ALLOC(args...); + } + + template + void GEMM_PACK(ARGS... args) const { + Base()->template GEMM_PACK(args...); + } + + template + void GEMM_COMPUTE(ARGS... args) const { + Base()->template GEMM_COMPUTE(args...); + } + + template + void GEMM_FREE(ARGS... args) const { + Base()->template GEMM_FREE(args...); + } +#endif + template void MatMul(ARGS... args) const { Base()->template MatMul(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index a0802ef90ca7e30a2b22d187cb9092163518d8e9..bf3382107960dfd8b52f94b421b49022dcb6d291 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -31,6 +31,26 @@ struct CBlas { platform::dynload::cblas_sgemm(args...); } + template + static float *GEMM_ALLOC(ARGS... args) { + return platform::dynload::cblas_sgemm_alloc(args...); + } + + template + static void GEMM_PACK(ARGS... args) { + platform::dynload::cblas_sgemm_pack(args...); + } + + template + static void GEMM_COMPUTE(ARGS... args) { + platform::dynload::cblas_sgemm_compute(args...); + } + + template + static void GEMM_FREE(ARGS... args) { + platform::dynload::cblas_sgemm_free(args...); + } + #ifdef PADDLE_WITH_LIBXSMM template static void SMM_GEMM(ARGS... args) { @@ -71,6 +91,26 @@ struct CBlas { platform::dynload::cblas_dgemm(args...); } + template + static double *GEMM_ALLOC(ARGS... args) { + return platform::dynload::cblas_dgemm_alloc(args...); + } + + template + static void GEMM_PACK(ARGS... args) { + platform::dynload::cblas_dgemm_pack(args...); + } + + template + static void GEMM_COMPUTE(ARGS... args) { + platform::dynload::cblas_dgemm_compute(args...); + } + + template + static void GEMM_FREE(ARGS... args) { + platform::dynload::cblas_dgemm_free(args...); + } + #ifdef PADDLE_WITH_LIBXSMM template static void SMM_GEMM(ARGS... args) { @@ -224,6 +264,41 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, beta, C, ldc); } +#ifdef PADDLE_WITH_MKLML +template <> +template +T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, + const int M, const int N, + const int K) const { + return CBlas::GEMM_ALLOC(id, M, N, K); +} + +template <> +template +void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, + const CBLAS_TRANSPOSE trans, + int M, int N, int K, + const T alpha, const T *src, + const int ld, T *dst) const { + CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); +} + +template <> +template +void Blas::GEMM_COMPUTE( + int transA, int transB, int M, int N, int K, const T *A, const int lda, + const T *B, const int ldb, T beta, T *C, const int ldc) const { + CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +template +void Blas::GEMM_FREE(T *data) const { + CBlas::GEMM_FREE(data); +} +#endif + template <> template void Blas::GEMM(CBLAS_TRANSPOSE transA, diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 17acefe8cde01809572e4c86cbdccfed9a477a51..9e7a616094e184695de521aa035257bde4170a91 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -60,6 +60,14 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_batch); \ __macro(vsAdd); \ __macro(vdAdd); \ + __macro(cblas_sgemm_alloc); \ + __macro(cblas_sgemm_pack); \ + __macro(cblas_sgemm_compute); \ + __macro(cblas_sgemm_free); \ + __macro(cblas_dgemm_alloc); \ + __macro(cblas_dgemm_pack); \ + __macro(cblas_dgemm_compute); \ + __macro(cblas_dgemm_free); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);