From e4dba69a2fdc793ca399042e688256108e0098fb Mon Sep 17 00:00:00 2001
From: Feiyu Chan
Date: Wed, 2 Mar 2022 10:23:15 +0800
Subject: [PATCH] [Pten] Gru lstm migration (#39729)

* Move sequence2batch
* Move lstm and gru
* Add the phi/kernels directory to the exclusion so that hipcc is not used
  to compile non-.cu files in it.
---
 cmake/generic.cmake | 4 +-
 .../fused/fused_embedding_fc_lstm_op.cc | 6 +-
 paddle/fluid/operators/fused/fusion_gru_op.cc | 6 +-
 .../fluid/operators/fused/fusion_lstm_op.cc | 6 +-
 paddle/fluid/operators/fused/multi_gru_op.cc | 2 +-
 paddle/fluid/operators/gru_op.cc | 28 +-
 paddle/fluid/operators/gru_op.cu.cc | 12 +-
 paddle/fluid/operators/gru_op.h | 22 +-
 paddle/fluid/operators/lstm_op.h | 38 +-
 paddle/fluid/operators/lstmp_op.h | 68 +--
 paddle/fluid/operators/math/CMakeLists.txt | 6 +-
 paddle/fluid/operators/math/gru_compute.h | 80 ----
 paddle/fluid/operators/math/lstm_compute.cc | 93 ----
 paddle/fluid/operators/math/lstm_compute.cu | 59 ---
 paddle/fluid/operators/rnn_op.h | 64 +-
 paddle/phi/kernels/funcs/CMakeLists.txt | 4 +
 .../kernels/funcs}/detail/CMakeLists.txt | 0
 .../funcs}/detail/activation_functions.h | 68 +--
 .../kernels/funcs}/detail/avx_functions.cc | 19 +-
 .../kernels/funcs}/detail/avx_mathfun.h | 6 +-
 .../kernels/funcs}/detail/gru_cpu_kernel.h | 451 ++++++++++++------
 .../kernels/funcs}/detail/gru_gpu_kernel.h | 106 ++--
 .../kernels/funcs}/detail/gru_kernel.h | 150 +++---
 .../kernels/funcs}/detail/lstm_cpu_kernel.h | 266 ++++++++---
 .../kernels/funcs}/detail/lstm_gpu_kernel.h | 159 ++++--
 .../kernels/funcs}/detail/lstm_kernel.h | 123 +++--
 paddle/phi/kernels/funcs/gru_compute.cc | 373 +++++++++++++++
 paddle/phi/kernels/funcs/gru_compute.cu | 349 ++++++++++++++
 paddle/phi/kernels/funcs/gru_compute.h | 88 ++++
 paddle/phi/kernels/funcs/lstm_compute.cc | 103 ++++
 paddle/phi/kernels/funcs/lstm_compute.cu | 76 +++
 .../math => phi/kernels/funcs}/lstm_compute.h | 39 +-
 .../kernels/funcs}/sequence2batch.cc | 62 +--
 .../kernels/funcs}/sequence2batch.cu | 72 +--
 .../kernels/funcs}/sequence2batch.h | 66 +--
 35 files changed, 2181 insertions(+), 893 deletions(-)
 delete mode 100644 paddle/fluid/operators/math/gru_compute.h
 delete mode 100644 paddle/fluid/operators/math/lstm_compute.cc
 delete mode 100644 paddle/fluid/operators/math/lstm_compute.cu
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/CMakeLists.txt (100%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/activation_functions.h (75%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/avx_functions.cc (87%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/avx_mathfun.h (99%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/gru_cpu_kernel.h (60%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/gru_gpu_kernel.h (74%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/gru_kernel.h (64%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/lstm_cpu_kernel.h (65%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/lstm_gpu_kernel.h (68%)
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/detail/lstm_kernel.h (59%)
 create mode 100644 paddle/phi/kernels/funcs/gru_compute.cc
 create mode 100644 paddle/phi/kernels/funcs/gru_compute.cu
 create mode 100644 paddle/phi/kernels/funcs/gru_compute.h
 create mode 100644 paddle/phi/kernels/funcs/lstm_compute.cc
 create mode 100644 paddle/phi/kernels/funcs/lstm_compute.cu
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/lstm_compute.h (56%)
rename paddle/{fluid/operators/math => phi/kernels/funcs}/sequence2batch.cc (56%) rename paddle/{fluid/operators/math => phi/kernels/funcs}/sequence2batch.cu (55%) rename paddle/{fluid/operators/math => phi/kernels/funcs}/sequence2batch.h (80%) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 51ed537ce5d..da81575188f 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -580,8 +580,8 @@ function(hip_library TARGET_NAME) cmake_parse_arguments(hip_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) if(hip_library_SRCS) # FindHIP.cmake defined hip_add_library, HIP_SOURCE_PROPERTY_FORMAT is requried if no .cu files found - if(NOT ${CMAKE_CURRENT_SOURCE_DIR} MATCHES ".*/operators") - set_source_files_properties(${hip_library_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + if(NOT (${CMAKE_CURRENT_SOURCE_DIR} MATCHES ".*/operators" OR ${CMAKE_CURRENT_SOURCE_DIR} MATCHES ".*/phi/kernels")) + set_source_files_properties(${hip_library_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) endif() if (hip_library_SHARED OR hip_library_shared) # build *.so hip_add_library(${TARGET_NAME} SHARED ${hip_library_SRCS}) diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 56c2c86e1a7..0c83c36b475 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -15,9 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h" #include #include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" namespace paddle { namespace operators { @@ -473,7 +473,7 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { hidden_out->mutable_data(place); cell_out->mutable_data(place); - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = ctx.template device_context(); auto blas = phi::funcs::GetBlas(dev_ctx); @@ -591,7 +591,7 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { #undef MOVE_ONE_BATCH #undef DEFINE_CUR - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; batched_h_out->set_lod(batched_lod); to_seq(dev_ctx, *batched_h_out, hidden_out); batched_c_out->set_lod(batched_lod); diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index 41a69031c54..3311e3b4ebc 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -19,8 +19,8 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/fc.h" -#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -368,7 +368,7 @@ class FusionGRUKernel : public framework::OpKernel { hidden_out->mutable_data(place); auto& dev_ctx = ctx.template device_context(); auto blas = phi::funcs::GetBlas(dev_ctx); - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; math::FCFunctor fc; if (M > D3) { @@ -463,7 +463,7 @@ class FusionGRUKernel : public framework::OpKernel { batched_input_data = cur_batched_data; } - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; batched_out->set_lod(batched_lod); to_seq(dev_ctx, *batched_out, hidden_out); } diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 06d406867f0..00be8b09d12 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -16,8 +16,8 @@ limitations under the License. */ #include #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/fc.h" -#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -421,7 +421,7 @@ class FuisonLSTMKernel : public framework::OpKernel { hidden_out->mutable_data(place); cell_out->mutable_data(place); - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = ctx.template device_context(); auto blas = phi::funcs::GetBlas(dev_ctx); math::FCFunctor fc; @@ -514,7 +514,7 @@ class FuisonLSTMKernel : public framework::OpKernel { batched_input_data = cur_in_data; } - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; batched_h_out->set_lod(batched_lod); to_seq(dev_ctx, *batched_h_out, hidden_out); batched_c_out->set_lod(batched_lod); diff --git a/paddle/fluid/operators/fused/multi_gru_op.cc b/paddle/fluid/operators/fused/multi_gru_op.cc index 84826ff3993..c2260c53b2e 100644 --- a/paddle/fluid/operators/fused/multi_gru_op.cc +++ b/paddle/fluid/operators/fused/multi_gru_op.cc @@ -19,8 +19,8 @@ limitations under the License. */ #include #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/fc.h" -#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 88530b5352d..d7cf03ddd61 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -15,9 +15,9 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/gru_op.h" #include #include -#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" -#include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h" +#include "paddle/phi/kernels/funcs/detail/gru_kernel.h" DECLARE_int32(paddle_num_threads); @@ -316,7 +316,7 @@ class GRUCPUKernel : public framework::OpKernel { batch_hidden->mutable_data(context.GetPlace()); bool is_reverse = context.Attr("is_reverse"); - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = context.template device_context(); to_batch(dev_ctx, *input, batch_gate, true, is_reverse); @@ -326,7 +326,7 @@ class GRUCPUKernel : public framework::OpKernel { } int frame_size = hidden_dims[1]; - math::GRUMetaValue gru_value; + phi::funcs::GRUMetaValue gru_value; gru_value.gate_weight = const_cast(weight_data); gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); @@ -347,9 +347,9 @@ class GRUCPUKernel : public framework::OpKernel { } auto batch_starts = batch_gate->lod()[0]; size_t seq_len = batch_starts.size() - 1; - auto active_node = math::detail::GetActivationType( + auto active_node = phi::funcs::detail::GetActivationType( context.Attr("activation")); - auto active_gate = math::detail::GetActivationType( + auto active_gate = phi::funcs::detail::GetActivationType( context.Attr("gate_activation")); #ifdef PADDLE_WITH_MKLML @@ -396,9 +396,9 @@ class GRUCPUKernel : public framework::OpKernel { frame_size * 2, T(1), gru_value.gate_value, frame_size * 3); } - math::detail::forward_reset_output( - math::detail::forward::gru_resetOutput(), gru_value, frame_size, - cur_batch_size, active_gate); + phi::funcs::detail::forward_reset_output( + phi::funcs::detail::forward::gru_resetOutput(), gru_value, + frame_size, cur_batch_size, active_gate); if (gru_value.prev_out_value) { blas.GEMM_COMPUTE( @@ -408,9 +408,9 @@ class GRUCPUKernel : public framework::OpKernel { frame_size * 3); } - math::detail::forward_final_output( - math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node, origin_mode); + phi::funcs::detail::forward_final_output( + phi::funcs::detail::forward::gru_finalOutput(), gru_value, + frame_size, cur_batch_size, active_node, origin_mode); gru_value.prev_out_value = gru_value.output_value; } @@ -432,7 +432,7 @@ class GRUCPUKernel : public framework::OpKernel { gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); - math::GRUUnitFunctor::compute( + phi::funcs::GRUUnitFunctor::compute( dev_ctx, gru_value, frame_size, cur_batch_size, active_node, active_gate, origin_mode); @@ -441,7 +441,7 @@ class GRUCPUKernel : public framework::OpKernel { #ifdef PADDLE_WITH_MKLML } #endif - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); to_seq(dev_ctx, *batch_hidden, hidden); } diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc index 7d055240916..5be0acc1543 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -65,7 +65,7 @@ class GRUKernel : public framework::OpKernel { batch_hidden->mutable_data(context.GetPlace()); bool is_reverse = context.Attr("is_reverse"); - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = context.template device_context(); 
to_batch(dev_ctx, *input, batch_gate, true, is_reverse); @@ -75,7 +75,7 @@ class GRUKernel : public framework::OpKernel { } int frame_size = hidden_dims[1]; - math::GRUMetaValue gru_value; + phi::funcs::GRUMetaValue gru_value; gru_value.gate_weight = const_cast(weight_data); gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); @@ -96,9 +96,9 @@ class GRUKernel : public framework::OpKernel { } auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; - auto active_node = math::detail::GetActivationType( + auto active_node = phi::funcs::detail::GetActivationType( context.Attr("activation")); - auto active_gate = math::detail::GetActivationType( + auto active_gate = phi::funcs::detail::GetActivationType( context.Attr("gate_activation")); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); @@ -111,13 +111,13 @@ class GRUKernel : public framework::OpKernel { gru_value.output_value = hidden_t.data(); gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); - math::GRUUnitFunctor::compute( + phi::funcs::GRUUnitFunctor::compute( dev_ctx, gru_value, frame_size, cur_batch_size, active_node, active_gate, origin_mode); gru_value.prev_out_value = gru_value.output_value; } - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); to_seq(dev_ctx, *batch_hidden, hidden); } diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 130b10c7390..852655034c8 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -16,10 +16,10 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/gru_compute.h" -#include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/gru_compute.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" namespace paddle { namespace operators { @@ -32,7 +32,7 @@ inline void ReorderInitState(const DeviceContext& ctx, const framework::Tensor& src, framework::Vector index_lod, framework::Tensor* dst, bool indexed_src) { - math::CopyMatrixRowsFunctor row_shuffle; + phi::funcs::CopyMatrixRowsFunctor row_shuffle; dst->mutable_data(src.dims(), ctx.GetPlace()); row_shuffle(ctx, src, index_lod, dst, indexed_src); } @@ -63,7 +63,7 @@ class GRUGradKernel : public framework::OpKernel { auto hidden_dims = hidden->dims(); int frame_size = hidden_dims[1]; - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad; batch_hidden_grad.mutable_data(hidden_dims, context.GetPlace()); batch_gate_grad.mutable_data(gate_dims, context.GetPlace()); @@ -93,12 +93,12 @@ class GRUGradKernel : public framework::OpKernel { batch_hidden_grad.set_lod(batch_hidden->lod()); to_batch(dev_ctx, *hidden_grad, &batch_hidden_grad, false, is_reverse); - math::GRUMetaValue gru_value; + phi::funcs::GRUMetaValue gru_value; gru_value.gate_weight = const_cast(weight_data); gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); - math::GRUMetaGrad gru_grad; + phi::funcs::GRUMetaGrad gru_grad; if (weight_grad) { 
gru_grad.gate_weight_grad = weight_grad->mutable_data(context.GetPlace()); @@ -112,9 +112,9 @@ class GRUGradKernel : public framework::OpKernel { auto batch_starts = batch_hidden_grad.lod()[0]; size_t num_batch = batch_starts.size() - 1; - auto active_node = math::detail::GetActivationType( + auto active_node = phi::funcs::detail::GetActivationType( context.Attr("activation")); - auto active_gate = math::detail::GetActivationType( + auto active_gate = phi::funcs::detail::GetActivationType( context.Attr("gate_activation")); for (int n = static_cast(num_batch) - 1; n >= 0; n--) { int bstart = static_cast(batch_starts[n]); @@ -145,13 +145,13 @@ class GRUGradKernel : public framework::OpKernel { gru_grad.prev_out_grad = hidden_prev_grad_t.data(); } gru_value.output_value = nullptr; - math::GRUUnitGradFunctor::compute( + phi::funcs::GRUUnitGradFunctor::compute( dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node, active_gate, origin_mode); } if (input_grad) { input_grad->mutable_data(context.GetPlace()); - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; batch_gate_grad.set_lod(batch_gate->lod()); to_seq(dev_ctx, batch_gate_grad, input_grad); } diff --git a/paddle/fluid/operators/lstm_op.h b/paddle/fluid/operators/lstm_op.h index 62f9cd26c41..4ec3072a96d 100644 --- a/paddle/fluid/operators/lstm_op.h +++ b/paddle/fluid/operators/lstm_op.h @@ -15,10 +15,10 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/lstm_compute.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" namespace paddle { namespace operators { @@ -31,7 +31,7 @@ inline void ReorderInitState(const DeviceContext& ctx, const framework::Tensor& src, framework::Vector index_lod, framework::Tensor* dst, bool indexed_src) { - math::CopyMatrixRowsFunctor row_shuffle; + phi::funcs::CopyMatrixRowsFunctor row_shuffle; dst->mutable_data(src.dims(), ctx.GetPlace()); row_shuffle(ctx, src, index_lod, dst, indexed_src); } @@ -64,7 +64,7 @@ class LSTMKernel : public framework::OpKernel { cell_out->mutable_data(ctx.GetPlace()); bool is_reverse = ctx.Attr("is_reverse"); - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; auto& device_ctx = ctx.template device_context(); to_batch(device_ctx, *input, batch_gate, true, is_reverse); @@ -80,7 +80,7 @@ class LSTMKernel : public framework::OpKernel { add_bias(device_ctx, *batch_gate, gate_bias, batch_gate); } - math::LstmMetaValue lstm_value; + phi::funcs::LstmMetaValue lstm_value; if (bias && ctx.Attr("use_peepholes")) { T* bias_data = const_cast(bias->data()); // the code style in LstmMetaValue will be updated later. 
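// Sketch of the packed bias buffer that the peephole pointers set just
// below index into; the 4 * frame_size offset appears in this file, while
// the +5/+6 steps are assumed here from the same pattern (four gate biases
// first, then the three peephole weight vectors):
//
//   bias_data: [ 4 * frame_size gate biases | w_ic | w_fc | w_oc ]
//
//   lstm_value.check_ig = bias_data + 4 * frame_size;  // peephole -> input gate
//   lstm_value.check_fg = bias_data + 5 * frame_size;  // peephole -> forget gate (assumed)
//   lstm_value.check_og = bias_data + 6 * frame_size;  // peephole -> output gate (assumed)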
@@ -121,11 +121,11 @@ class LSTMKernel : public framework::OpKernel { auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; - auto gate_act = math::detail::GetActivationType( + auto gate_act = phi::funcs::detail::GetActivationType( ctx.Attr("gate_activation")); - auto cell_act = math::detail::GetActivationType( + auto cell_act = phi::funcs::detail::GetActivationType( ctx.Attr("cell_activation")); - auto cand_act = math::detail::GetActivationType( + auto cand_act = phi::funcs::detail::GetActivationType( ctx.Attr("candidate_activation")); auto blas = phi::funcs::GetBlas(device_ctx); @@ -166,13 +166,13 @@ class LSTMKernel : public framework::OpKernel { lstm_value.state_value = cell_t.data(); lstm_value.state_active_value = cell_pre_act_t.data(); T cell_clip = 0.0; - math::LstmUnitFunctor::compute( + phi::funcs::LstmUnitFunctor::compute( device_ctx, lstm_value, frame_size, cur_batch_size, cell_clip, gate_act, cell_act, cand_act); lstm_value.prev_state_value = lstm_value.state_value; } - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; batch_hidden.set_lod(batch_gate->lod()); // restore the output hidden in LoDTensor from the batch hidden to_seq(device_ctx, batch_hidden, hidden_out); @@ -241,7 +241,7 @@ class LSTMGradKernel : public framework::OpKernel { ") should be %d, but received %d in LSTM@Grad operator.", frame_size, out_dims[1])); - math::LstmMetaValue lstm_value; + phi::funcs::LstmMetaValue lstm_value; if (bias && ctx.Attr("use_peepholes")) { T* bias_data = const_cast(bias->data()); lstm_value.check_ig = bias_data + 4 * frame_size; @@ -253,7 +253,7 @@ class LSTMGradKernel : public framework::OpKernel { lstm_value.check_og = nullptr; } - math::LstmMetaGrad lstm_grad; + phi::funcs::LstmMetaGrad lstm_grad; if (bias && bias_g) { bias_g->mutable_data(ctx.GetPlace()); @@ -270,7 +270,7 @@ class LSTMGradKernel : public framework::OpKernel { lstm_grad.check_og_grad = nullptr; } - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; auto ToBatch = [&batch_gate, &to_batch]( const DeviceContext& ctx, const framework::LoDTensor& src, @@ -293,11 +293,11 @@ class LSTMGradKernel : public framework::OpKernel { batch_gate_g.mutable_data(batch_gate->dims(), ctx.GetPlace()); batch_gate_g.set_lod(batch_gate->lod()); - auto gate_act = math::detail::GetActivationType( + auto gate_act = phi::funcs::detail::GetActivationType( ctx.Attr("gate_activation")); - auto cell_act = math::detail::GetActivationType( + auto cell_act = phi::funcs::detail::GetActivationType( ctx.Attr("cell_activation")); - auto cand_act = math::detail::GetActivationType( + auto cand_act = phi::funcs::detail::GetActivationType( ctx.Attr("candidate_activation")); auto batch_starts = batch_gate->lod()[0]; @@ -338,7 +338,7 @@ class LSTMGradKernel : public framework::OpKernel { lstm_grad.state_active_grad = nullptr; int cur_batch_size = bend - bstart; T cell_clip = 0.0; - math::LstmUnitGradFunctor::compute( + phi::funcs::LstmUnitGradFunctor::compute( device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size, cell_clip, gate_act, cell_act, cand_act); @@ -369,7 +369,7 @@ class LSTMGradKernel : public framework::OpKernel { } } - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; if (in_g) { /* backward data */ in_g->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/lstmp_op.h b/paddle/fluid/operators/lstmp_op.h index 96c074f1efb..5d24c0b70d3 100644 --- a/paddle/fluid/operators/lstmp_op.h +++ 
b/paddle/fluid/operators/lstmp_op.h @@ -18,12 +18,12 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/transform.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/lstm_compute.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" namespace paddle { namespace operators { @@ -72,7 +72,7 @@ inline void ReorderInitState(const DeviceContext& ctx, const framework::Tensor& src, framework::Vector index, framework::Tensor* dst, bool indexed_src) { - math::CopyMatrixRowsFunctor row_shuffle; + phi::funcs::CopyMatrixRowsFunctor row_shuffle; dst->mutable_data(src.dims(), ctx.GetPlace()); row_shuffle(ctx, src, index, dst, indexed_src); } @@ -81,15 +81,15 @@ template class LSTMPKernel : public framework::OpKernel { public: template - void ActCompute(const math::detail::ActivationType act_type, const Device& d, - X x, Y y, platform::Place place) const { - if (act_type == math::detail::ActivationType::kIdentity) { + void ActCompute(const phi::funcs::detail::ActivationType act_type, + const Device& d, X x, Y y, platform::Place place) const { + if (act_type == phi::funcs::detail::ActivationType::kIdentity) { y.device(d) = x; - } else if (act_type == math::detail::ActivationType::kSigmoid) { + } else if (act_type == phi::funcs::detail::ActivationType::kSigmoid) { SigmoidFunctor()(d, x, y); - } else if (act_type == math::detail::ActivationType::kTanh) { + } else if (act_type == phi::funcs::detail::ActivationType::kTanh) { TanhFunctor()(d, x, y); - } else if (act_type == math::detail::ActivationType::kReLU) { + } else if (act_type == phi::funcs::detail::ActivationType::kReLU) { if (place == platform::CPUPlace()) ReluCPUFunctor()(d, x, y); else @@ -120,7 +120,7 @@ class LSTMPKernel : public framework::OpKernel { cell_out->mutable_data(ctx.GetPlace()); bool is_reverse = ctx.Attr("is_reverse"); - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; auto& device_ctx = ctx.template device_context(); to_batch(device_ctx, *input, batch_gate, true, is_reverse); @@ -137,7 +137,7 @@ class LSTMPKernel : public framework::OpKernel { add_bias(device_ctx, *batch_gate, gate_bias, batch_gate); } - math::LstmMetaValue lstmp_value; + phi::funcs::LstmMetaValue lstmp_value; if (bias && ctx.Attr("use_peepholes")) { T* bias_data = const_cast(bias->data()); // the code style in LstmpMetaValue will be updated later. 
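// Hedged sketch of the recurrent projection that distinguishes LSTMP from
// the plain LSTM above. sigma_proj is the "proj_activation" attribute; when
// it is kIdentity the activation is skipped, as the check below shows, and
// proj_weight is the tensor multiplied in below:
//
//   h_t = o_t * act_cell(c_t)            // ordinary LSTM unit output
//   r_t = sigma_proj(h_t x proj_weight)  // [batch, frame_size] x [frame_size, proj_size]
//
// r_t, not h_t, is fed back as prev_out_value on the next time step, which
// is why the projection MatMul below runs inside the per-batch loop.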
@@ -176,13 +176,13 @@ class LSTMPKernel : public framework::OpKernel { auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; - auto gate_act = math::detail::GetActivationType( + auto gate_act = phi::funcs::detail::GetActivationType( ctx.Attr("gate_activation")); - auto cell_act = math::detail::GetActivationType( + auto cell_act = phi::funcs::detail::GetActivationType( ctx.Attr("cell_activation")); - auto cand_act = math::detail::GetActivationType( + auto cand_act = phi::funcs::detail::GetActivationType( ctx.Attr("candidate_activation")); - auto proj_act = math::detail::GetActivationType( + auto proj_act = phi::funcs::detail::GetActivationType( ctx.Attr("proj_activation")); auto& place = *ctx.template device_context().eigen_device(); auto blas = phi::funcs::GetBlas(device_ctx); @@ -222,13 +222,13 @@ class LSTMPKernel : public framework::OpKernel { lstmp_value.output_value = hidden_t.data(); lstmp_value.state_value = cell_t.data(); lstmp_value.state_active_value = cell_pre_act_t.data(); - math::LstmUnitFunctor::compute( + phi::funcs::LstmUnitFunctor::compute( device_ctx, lstmp_value, frame_size, cur_batch_size, cell_clip, gate_act, cell_act, cand_act); lstmp_value.prev_state_value = lstmp_value.state_value; blas.MatMul(hidden_t, false, *proj_weight, false, static_cast(1.0), &proj_t, static_cast(0.0)); - if (proj_act != math::detail::ActivationType::kIdentity) { + if (proj_act != phi::funcs::detail::ActivationType::kIdentity) { auto proj_t_dev = EigenMatrix::From(proj_t); ActCompute(cell_act, place, proj_t_dev, proj_t_dev, ctx.GetPlace()); } @@ -242,7 +242,7 @@ class LSTMPKernel : public framework::OpKernel { } } - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; batch_proj.set_lod(batch_gate->lod()); // restore the output hidden in LoDTensor from the batch hidden to_seq(device_ctx, batch_proj, proj_out); @@ -257,16 +257,16 @@ template class LSTMPGradKernel : public framework::OpKernel { public: template - void ActGradCompute(const math::detail::ActivationType act_type, + void ActGradCompute(const phi::funcs::detail::ActivationType act_type, const Device& d, X x, Y y, DX dx, DY dy) const { // x is dummy and won't be used even in Relu(use y instead) - if (act_type == math::detail::ActivationType::kIdentity) + if (act_type == phi::funcs::detail::ActivationType::kIdentity) dx.device(d) = dy; - else if (act_type == math::detail::ActivationType::kSigmoid) + else if (act_type == phi::funcs::detail::ActivationType::kSigmoid) SigmoidGradFunctor()(d, x, y, dy, dx); - else if (act_type == math::detail::ActivationType::kTanh) + else if (act_type == phi::funcs::detail::ActivationType::kTanh) TanhGradFunctor()(d, x, y, dy, dx); - else if (act_type == math::detail::ActivationType::kReLU) + else if (act_type == phi::funcs::detail::ActivationType::kReLU) ReluGradFunctor()(d, x, y, dy, dx); else PADDLE_THROW( @@ -340,7 +340,7 @@ class LSTMPGradKernel : public framework::OpKernel { "but received %d in LSTMP@Grad operator.", frame_size, out_dims[1])); - math::LstmMetaValue lstmp_value; + phi::funcs::LstmMetaValue lstmp_value; if (bias && ctx.Attr("use_peepholes")) { T* bias_data = const_cast(bias->data()); lstmp_value.check_ig = bias_data + 4 * frame_size; @@ -352,7 +352,7 @@ class LSTMPGradKernel : public framework::OpKernel { lstmp_value.check_og = nullptr; } - math::LstmMetaGrad lstmp_grad; + phi::funcs::LstmMetaGrad lstmp_grad; if (bias && bias_g) { bias_g->mutable_data(ctx.GetPlace()); @@ -369,7 +369,7 @@ class LSTMPGradKernel : public 
framework::OpKernel { lstmp_grad.check_og_grad = nullptr; } - math::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; auto ToBatch = [&batch_gate, &to_batch]( const DeviceContext& ctx, const framework::LoDTensor& src, @@ -393,13 +393,13 @@ class LSTMPGradKernel : public framework::OpKernel { batch_gate_g.mutable_data(batch_gate->dims(), ctx.GetPlace()); batch_gate_g.set_lod(batch_gate->lod()); - auto gate_act = math::detail::GetActivationType( + auto gate_act = phi::funcs::detail::GetActivationType( ctx.Attr("gate_activation")); - auto cell_act = math::detail::GetActivationType( + auto cell_act = phi::funcs::detail::GetActivationType( ctx.Attr("cell_activation")); - auto cand_act = math::detail::GetActivationType( + auto cand_act = phi::funcs::detail::GetActivationType( ctx.Attr("candidate_activation")); - auto proj_act = math::detail::GetActivationType( + auto proj_act = phi::funcs::detail::GetActivationType( ctx.Attr("proj_activation")); auto& place = *ctx.template device_context().eigen_device(); @@ -423,7 +423,7 @@ class LSTMPGradKernel : public framework::OpKernel { _ClipGradFunctor(-1.0 * proj_clip, proj_clip)); } - if (proj_act != math::detail::ActivationType::kIdentity) { + if (proj_act != phi::funcs::detail::ActivationType::kIdentity) { auto cur_proj_dev = EigenMatrix::From(cur_proj); auto proj_g_dev = EigenMatrix::From(proj_g); ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev, @@ -470,7 +470,7 @@ class LSTMPGradKernel : public framework::OpKernel { lstmp_value.output_value = nullptr; lstmp_grad.state_active_grad = nullptr; - math::LstmUnitGradFunctor::compute( + phi::funcs::LstmUnitGradFunctor::compute( device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size, cell_clip, gate_act, cell_act, cand_act); @@ -503,7 +503,7 @@ class LSTMPGradKernel : public framework::OpKernel { } } - math::Batch2LoDTensorFunctor to_seq; + phi::funcs::Batch2LoDTensorFunctor to_seq; if (in_g) { /* backward data */ in_g->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index ac6566a8703..ba047355ad7 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -1,5 +1,3 @@ -add_subdirectory(detail) - if (WITH_ASCEND_CL) cc_library(beam_search_npu SRCS beam_search_npu.cc DEPS npu_op_runner) endif() @@ -18,8 +16,7 @@ math_library(im2col) math_library(sample_prob) math_library(sampler DEPS generator) -math_library(gru_compute DEPS activation_functions math_function) -math_library(lstm_compute DEPS activation_functions) +# math_library(math_function DEPS blas dense_tensor tensor) math_library(maxouting) math_library(pooling) @@ -29,7 +26,6 @@ else() math_library(selected_rows_functor DEPS selected_rows_utils math_function blas) endif() -math_library(sequence2batch) math_library(sequence_padding) math_library(sequence_pooling DEPS math_function jit_kernel_helper) math_library(sequence_scale) diff --git a/paddle/fluid/operators/math/gru_compute.h b/paddle/fluid/operators/math/gru_compute.h deleted file mode 100644 index 70cbfecefc8..00000000000 --- a/paddle/fluid/operators/math/gru_compute.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace math { - -template -struct GRUMetaValue { - const T *gate_weight; - const T *state_weight; - const T *reset_bias; - T *gate_value; - T *reset_output_value; - T *output_value; - const T *prev_out_value; -}; - -template -struct GRUMetaGrad { - T *gate_weight_grad; - T *state_weight_grad; - T *gate_grad; - T *reset_output_grad; - T *output_grad; - T *prev_out_grad; - T *bias_hh_grad; -}; - -template -struct GRUUnitFunctor { - static void compute(const DeviceContext &context, GRUMetaValue value, - int frame_size, int batch_size, - const detail::ActivationType active_node, - const detail::ActivationType active_gate, - bool origin_mode); -}; - -template -struct GRUUnitGradFunctor { - static void compute(const DeviceContext &context, GRUMetaValue value, - GRUMetaGrad grad, int frame_size, int batch_size, - const detail::ActivationType active_node, - const detail::ActivationType active_gate, - bool origin_mode); -}; - -template -struct GRUUnitFunctorV2 { - static void compute(const DeviceContext &context, GRUMetaValue value, - int frame_size, int batch_size, - const detail::ActivationType active_node, - const detail::ActivationType active_gate); -}; - -template -struct GRUUnitGradFunctorV2 { - static void compute(const DeviceContext &context, GRUMetaValue value, - GRUMetaGrad grad, int frame_size, int batch_size, - const detail::ActivationType active_node, - const detail::ActivationType active_gate); -}; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/lstm_compute.cc b/paddle/fluid/operators/math/lstm_compute.cc deleted file mode 100644 index aa4fe65a520..00000000000 --- a/paddle/fluid/operators/math/lstm_compute.cc +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/math/lstm_compute.h" - -#include "paddle/fluid/operators/math/detail/lstm_cpu_kernel.h" -#include "paddle/fluid/operators/math/detail/lstm_kernel.h" - -namespace paddle { -namespace platform { -class CPUDeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace math { - -template -struct LstmUnitFunctor { - static void compute(const platform::CPUDeviceContext& context, - LstmMetaValue value, int frame_size, int batch_size, - T cell_clip, const detail::ActivationType& gate_act, - const detail::ActivationType& cell_act, - const detail::ActivationType& cand_act, - bool old_api_version = true) { - for (int b = 0; b < batch_size; b++) { - detail::cpu_lstm_forward(context, detail::forward::lstm(), value, - frame_size, cell_clip, cand_act, gate_act, - cell_act, old_api_version); - value.gate_value += frame_size * 4; - value.state_value += frame_size; - value.state_active_value += frame_size; - value.output_value += frame_size; - if (value.prev_state_value) { - value.prev_state_value += frame_size; - } - } - } -}; - -template -struct LstmUnitGradFunctor { - static void compute(const platform::CPUDeviceContext& context, - LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, T cell_clip, - const detail::ActivationType& gate_act, - const detail::ActivationType& cell_act, - const detail::ActivationType& cand_act, - bool old_api_version = true) { - for (int b = 0; b < batch_size; b++) { - detail::cpu_lstm_backward(context, detail::backward::lstm(), value, - grad, frame_size, cell_clip, cand_act, gate_act, - cell_act, old_api_version); - - value.gate_value += frame_size * 4; - value.state_value += frame_size; - value.state_active_value += frame_size; - value.output_value += frame_size; - if (value.prev_state_value) { - value.prev_state_value += frame_size; - } - - grad.gate_grad += frame_size * 4; - grad.state_grad += frame_size; - grad.state_active_grad += frame_size; - grad.output_grad += frame_size; - if (grad.prev_state_grad) { - grad.prev_state_grad += frame_size; - } - } - } -}; - -template class LstmUnitFunctor; -template class LstmUnitFunctor; -template class LstmUnitGradFunctor; -template class LstmUnitGradFunctor; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/lstm_compute.cu b/paddle/fluid/operators/math/lstm_compute.cu deleted file mode 100644 index 4342cb7b799..00000000000 --- a/paddle/fluid/operators/math/lstm_compute.cu +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/math/detail/lstm_gpu_kernel.h" -#include "paddle/fluid/operators/math/detail/lstm_kernel.h" -#include "paddle/fluid/operators/math/lstm_compute.h" - -namespace paddle { -namespace operators { -namespace math { - -template -struct LstmUnitFunctor { - static void compute(const platform::CUDADeviceContext& context, - LstmMetaValue value, int frame_size, int batch_size, - T cell_clip, const detail::ActivationType& gate_act, - const detail::ActivationType& cell_act, - const detail::ActivationType& cand_act, - bool old_api_version = true) { - detail::gpu_lstm_forward(context, detail::forward::lstm(), value, - frame_size, batch_size, cell_clip, cand_act, - gate_act, cell_act); - } -}; - -template -struct LstmUnitGradFunctor { - static void compute(const platform::CUDADeviceContext& context, - LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, T cell_clip, - const detail::ActivationType& gate_act, - const detail::ActivationType& cell_act, - const detail::ActivationType& cand_act, - bool old_api_version = true) { - detail::gpu_lstm_backward(context, detail::backward::lstm(), value, grad, - frame_size, batch_size, cell_clip, cand_act, - gate_act, cell_act); - } -}; - -template class LstmUnitFunctor; -template class LstmUnitFunctor; -template class LstmUnitGradFunctor; -template class LstmUnitGradFunctor; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/rnn_op.h b/paddle/fluid/operators/rnn_op.h index c18570af775..b636184ae45 100644 --- a/paddle/fluid/operators/rnn_op.h +++ b/paddle/fluid/operators/rnn_op.h @@ -20,13 +20,13 @@ limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/fc.h" -#include "paddle/fluid/operators/math/gru_compute.h" -#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/unique_op.h" #include "paddle/fluid/operators/utils.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/gru_compute.h" +#include "paddle/phi/kernels/funcs/lstm_compute.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -100,7 +100,7 @@ struct Cell { }; template class EigenActivationFunctor, - math::detail::ActivationType act_type> + phi::funcs::detail::ActivationType act_type> struct SimpleRNNCell : Cell { void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, const Tensor* weight_hh, const Tensor* init_h, @@ -148,7 +148,7 @@ struct GRUCell : Cell { size_t frame_size = init_h->dims()[2]; size_t batch_size = init_h->dims()[1]; - math::GRUMetaValue gru_value; + phi::funcs::GRUMetaValue gru_value; gru_value.gate_weight = weight_hh->data(); gru_value.state_weight = weight_hh->data() + 2 * frame_size * frame_size; gru_value.reset_bias = bias_hh->data() + 2 * frame_size; @@ -158,10 +158,10 @@ struct GRUCell : Cell { gru_value.output_value = output->data(); gru_value.prev_out_value = init_h->data(); - auto gate_act = math::detail::GetActivationType("sigmoid_v2"); - auto cand_act = math::detail::GetActivationType("tanh_v2"); + auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2"); + auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2"); - 
math::GRUUnitFunctorV2::compute( + phi::funcs::GRUUnitFunctorV2::compute( *device_ctx, gru_value, frame_size, batch_size, cand_act, gate_act); } }; @@ -184,14 +184,14 @@ struct LSTMCell : Cell { blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast(1.0), input, static_cast(1.0)); - math::LstmMetaValue lstm_value; + phi::funcs::LstmMetaValue lstm_value; lstm_value.check_ig = nullptr; lstm_value.check_fg = nullptr; lstm_value.check_og = nullptr; - auto gate_act = math::detail::GetActivationType("sigmoid_v2"); - auto cell_act = math::detail::GetActivationType("tanh_v2"); - auto cand_act = math::detail::GetActivationType("tanh_v2"); + auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2"); + auto cell_act = phi::funcs::detail::GetActivationType("tanh_v2"); + auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2"); size_t frame_size = init_h->dims()[2]; size_t batch_size = init_h->dims()[1]; @@ -208,7 +208,7 @@ struct LSTMCell : Cell { lstm_value.state_value = last_c->data(); lstm_value.state_active_value = last_c_act->data(); T cell_clip = 0.0; - math::LstmUnitFunctor::compute( + phi::funcs::LstmUnitFunctor::compute( *device_ctx, lstm_value, frame_size, batch_size, cell_clip, gate_act, cell_act, cand_act, false); } @@ -986,18 +986,18 @@ class RNNCPUKernel : public framework::OpKernel { seed, reserve_data); } else if (is_rnn_relu(ctx)) { gate_num = 1; - RnnFunc< - SimpleRNNCell, - Layer, SingleLayer, BidirLayer, T>( + RnnFunc, + Layer, SingleLayer, BidirLayer, T>( ctx, input, weight_list, pre_state[0], nullptr, sequence_length, state[0], nullptr, output, dropout_mask, num_layers, gate_num, input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, seed, reserve_data); } else if (is_rnn_tanh(ctx)) { gate_num = 1; - RnnFunc< - SimpleRNNCell, - Layer, SingleLayer, BidirLayer, T>( + RnnFunc, + Layer, SingleLayer, BidirLayer, T>( ctx, input, weight_list, pre_state[0], nullptr, sequence_length, state[0], nullptr, output, dropout_mask, num_layers, gate_num, input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, @@ -1014,14 +1014,14 @@ class RNNCPUKernel : public framework::OpKernel { }; template -void create_lstm_value(math::LstmMetaValue* lstm_value) { +void create_lstm_value(phi::funcs::LstmMetaValue* lstm_value) { lstm_value->check_ig = nullptr; lstm_value->check_fg = nullptr; lstm_value->check_og = nullptr; } template -void create_lstm_grad(math::LstmMetaGrad* lstm_grad) { +void create_lstm_grad(phi::funcs::LstmMetaGrad* lstm_grad) { lstm_grad->check_ig_grad = nullptr; lstm_grad->check_fg_grad = nullptr; lstm_grad->check_og_grad = nullptr; @@ -1686,8 +1686,8 @@ struct GRUGradCell : GradCell { // zero pre_hidden phi::funcs::SetConstant zero; zero(device_ctx, grad_pre_hidden, static_cast(0.0)); - math::GRUMetaValue gru_value; - math::GRUMetaGrad gru_grad; + phi::funcs::GRUMetaValue gru_value; + phi::funcs::GRUMetaGrad gru_grad; gru_value.gate_value = gate_tensor->data(); gru_value.prev_out_value = pre_hidden->data(); gru_value.reset_output_value = state_tensor->data(); @@ -1703,9 +1703,9 @@ struct GRUGradCell : GradCell { grad_weight_hh->data() + 2 * frame_size * frame_size; gru_grad.bias_hh_grad = grad_bias_hh->data(); - auto act_gate = math::detail::GetActivationType("sigmoid_v2"); - auto act_node = math::detail::GetActivationType("tanh_v2"); - math::GRUUnitGradFunctorV2::compute( + auto act_gate = phi::funcs::detail::GetActivationType("sigmoid_v2"); + auto act_node = phi::funcs::detail::GetActivationType("tanh_v2"); + 
phi::funcs::GRUUnitGradFunctorV2::compute( device_ctx, gru_value, gru_grad, frame_size, batch_size, act_node, act_gate); @@ -1738,8 +1738,8 @@ struct LSTMGradCell : GradCell { backup_tensor(context, &grad_pre_state_bak, grad_pre_state); } - math::LstmMetaValue lstm_value; - math::LstmMetaGrad lstm_grad; + phi::funcs::LstmMetaValue lstm_value; + phi::funcs::LstmMetaGrad lstm_grad; create_lstm_value(&lstm_value); create_lstm_grad(&lstm_grad); lstm_value.gate_value = gate_tensor->data(); @@ -1755,12 +1755,12 @@ struct LSTMGradCell : GradCell { lstm_value.output_value = nullptr; lstm_grad.state_active_grad = nullptr; - auto gate_act = math::detail::GetActivationType("sigmoid_v2"); - auto state_act = math::detail::GetActivationType("tanh_v2"); - auto cand_act = math::detail::GetActivationType("tanh_v2"); + auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2"); + auto state_act = phi::funcs::detail::GetActivationType("tanh_v2"); + auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2"); T cell_clip = 0.0; - math::LstmUnitGradFunctor::compute( + phi::funcs::LstmUnitGradFunctor::compute( device_ctx, lstm_value, lstm_grad, frame_size, batch_size, cell_clip, gate_act, state_act, cand_act, false); this->update_pre_hidden_grad( diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index aa4fac16920..8b8697b6df1 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -1,6 +1,10 @@ add_subdirectory(eigen) add_subdirectory(blas) add_subdirectory(lapack) +add_subdirectory(detail) math_library(math_function DEPS blas dense_tensor tensor) +math_library(sequence2batch) +math_library(gru_compute DEPS activation_functions math_function) +math_library(lstm_compute DEPS activation_functions) math_library(concat_and_split_functor DEPS dense_tensor) diff --git a/paddle/fluid/operators/math/detail/CMakeLists.txt b/paddle/phi/kernels/funcs/detail/CMakeLists.txt similarity index 100% rename from paddle/fluid/operators/math/detail/CMakeLists.txt rename to paddle/phi/kernels/funcs/detail/CMakeLists.txt diff --git a/paddle/fluid/operators/math/detail/activation_functions.h b/paddle/phi/kernels/funcs/detail/activation_functions.h similarity index 75% rename from paddle/fluid/operators/math/detail/activation_functions.h rename to paddle/phi/kernels/funcs/detail/activation_functions.h index 1fac60e7cb8..475557f1642 100644 --- a/paddle/fluid/operators/math/detail/activation_functions.h +++ b/paddle/phi/kernels/funcs/detail/activation_functions.h @@ -19,9 +19,8 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/cpu_info.h" #include "paddle/phi/core/hostdevice.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { namespace detail { #define SIGMOID_THRESHOLD_MIN -40.0 @@ -132,25 +131,35 @@ struct Active { #ifdef PADDLE_WITH_CUDA -static DEVICE Active::Act kActFloat[] = { - &forward::Sigmoid, &forward::SigmoidV2, - &forward::Relu, &forward::Tanh, - &forward::TanhV2, &forward::Identity}; +static DEVICE Active::Act kActFloat[] = {&forward::Sigmoid, + &forward::SigmoidV2, + &forward::Relu, + &forward::Tanh, + &forward::TanhV2, + &forward::Identity}; static DEVICE Active::ActGrad kActGradFloat[] = { - &backward::Sigmoid, &backward::Sigmoid, - &backward::Relu, &backward::Tanh, - &backward::Tanh, &backward::Identity}; - -static DEVICE Active::Act kActDouble[] = { - &forward::Sigmoid, &forward::SigmoidV2, - &forward::Relu, &forward::Tanh, - &forward::TanhV2, &forward::Identity}; + &backward::Sigmoid, + &backward::Sigmoid, + &backward::Relu, + &backward::Tanh, + &backward::Tanh, + &backward::Identity}; + +static DEVICE Active::Act kActDouble[] = {&forward::Sigmoid, + &forward::SigmoidV2, + &forward::Relu, + &forward::Tanh, + &forward::TanhV2, + &forward::Identity}; static DEVICE Active::ActGrad kActGradDouble[] = { - &backward::Sigmoid, &backward::Sigmoid, - &backward::Relu, &backward::Tanh, - &backward::Tanh, &backward::Identity}; + &backward::Sigmoid, + &backward::Sigmoid, + &backward::Relu, + &backward::Tanh, + &backward::Tanh, + &backward::Identity}; namespace forward { inline DEVICE float activation(float a, int index) { @@ -287,13 +296,19 @@ __m256 Identity(const __m256 a, const __m256 b); } // namespace avx } // namespace backward -static Active<__m256>::Act kActAvx[] = { - &forward::avx::Sigmoid, &forward::avx::SigmoidV2, &forward::avx::Relu, - &forward::avx::Tanh, &forward::avx::TanhV2, &forward::avx::Identity}; +static Active<__m256>::Act kActAvx[] = {&forward::avx::Sigmoid, + &forward::avx::SigmoidV2, + &forward::avx::Relu, + &forward::avx::Tanh, + &forward::avx::TanhV2, + &forward::avx::Identity}; -static Active<__m256>::ActGrad kActGradAvx[] = { - &backward::avx::Sigmoid, &backward::avx::Sigmoid, &backward::avx::Relu, - &backward::avx::Tanh, &backward::avx::Tanh, &backward::avx::Identity}; +static Active<__m256>::ActGrad kActGradAvx[] = {&backward::avx::Sigmoid, + &backward::avx::Sigmoid, + &backward::avx::Relu, + &backward::avx::Tanh, + &backward::avx::Tanh, + &backward::avx::Identity}; namespace forward { inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); } @@ -308,6 +323,5 @@ inline __m256 activation(__m256 a, __m256 b, int index) { #endif } // namespace detail -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/detail/avx_functions.cc b/paddle/phi/kernels/funcs/detail/avx_functions.cc similarity index 87% rename from paddle/fluid/operators/math/detail/avx_functions.cc rename to paddle/phi/kernels/funcs/detail/avx_functions.cc index 89e2c825c24..51af97857df 100644 --- a/paddle/fluid/operators/math/detail/avx_functions.cc +++ b/paddle/phi/kernels/funcs/detail/avx_functions.cc @@ -14,12 +14,11 @@ limitations under the License. 
*/ #ifdef __AVX__ -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/detail/avx_mathfun.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/detail/avx_mathfun.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { namespace detail { __m256 Exp(__m256 a) { return exp256_ps(a); } @@ -77,8 +76,9 @@ namespace backward { namespace avx { __m256 Relu(const __m256 a, const __m256 b) { return _mm256_mul_ps( - a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS), - _mm256_set1_ps(1.0f))); + a, + _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS), + _mm256_set1_ps(1.0f))); } __m256 Sigmoid(const __m256 a, const __m256 b) { @@ -96,8 +96,7 @@ __m256 Identity(const __m256 a, const __m256 b) { return a; } } // namespace backward } // namespace detail -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi #endif diff --git a/paddle/fluid/operators/math/detail/avx_mathfun.h b/paddle/phi/kernels/funcs/detail/avx_mathfun.h similarity index 99% rename from paddle/fluid/operators/math/detail/avx_mathfun.h rename to paddle/phi/kernels/funcs/detail/avx_mathfun.h index d7cf91134e4..e5e7388d51d 100644 --- a/paddle/fluid/operators/math/detail/avx_mathfun.h +++ b/paddle/phi/kernels/funcs/detail/avx_mathfun.h @@ -49,9 +49,9 @@ typedef __m256 v8sf; // vector of 8 float (avx) typedef __m256i v8si; // vector of 8 int (avx) typedef __m128i v4si; // vector of 8 int (avx) -#define _PI32AVX_CONST(Name, Val) \ - static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = {Val, Val, \ - Val, Val} +#define _PI32AVX_CONST(Name, Val) \ + static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = { \ + Val, Val, Val, Val} _PI32AVX_CONST(1, 1); _PI32AVX_CONST(inv1, ~1); diff --git a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h b/paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h similarity index 60% rename from paddle/fluid/operators/math/detail/gru_cpu_kernel.h rename to paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h index cbbfbc321b5..cb37daa680e 100644 --- a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h @@ -16,24 +16,28 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/gru_compute.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/gru_compute.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { namespace detail { using Array1 = Eigen::DSizes; -template -using EigenVector = framework::EigenVector; +using EigenVector = paddle::framework::EigenVector; #if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group for GRU CPU template -void hl_naive_gru_forward_reset_output( - OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, - const T *prev_output_value, int frame_size, ActivationType active_gate, - bool old_version = true, const T *reset_bias = nullptr) { +void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, + T *gate_value, + T *reset_output_value, + const T *prev_output_value, + int frame_size, + ActivationType active_gate, + bool old_version = true, + const T *reset_bias = nullptr) { T r_value_update_gate; T r_value_reset_gate; T r_value_reset_output; @@ -59,8 +63,12 @@ void hl_naive_gru_forward_reset_output( r_prev_out = prev_output_value[i]; } - op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, - &r_value_reset_output, active_gate, &r_reset_bias, + op_reset_output(&r_value_update_gate, + &r_value_reset_gate, + &r_prev_out, + &r_value_reset_output, + active_gate, + &r_reset_bias, old_version); update_gate[i] = r_value_update_gate; @@ -70,10 +78,14 @@ void hl_naive_gru_forward_reset_output( } template -void hl_naive_gru_forward_final_output( - OpFinalOutput op_final_output, T *gate_value, const T *prev_output_value, - T *output_value, int frame_size, ActivationType active_node, - bool origin_mode, bool old_version = true) { +void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, + T *gate_value, + const T *prev_output_value, + T *output_value, + int frame_size, + ActivationType active_node, + bool origin_mode, + bool old_version = true) { T r_value_update_gate; T r_value_frame_state; T r_prev_out = 0; @@ -93,8 +105,12 @@ void hl_naive_gru_forward_final_output( r_prev_out = prev_output_value[i]; } - op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node, origin_mode); + op_final_output(&r_value_update_gate, + &r_value_frame_state, + &r_prev_out, + &r_output, + active_node, + origin_mode); frame_state[i] = r_value_frame_state; output_value[i] = r_output; @@ -103,8 +119,10 @@ void hl_naive_gru_forward_final_output( template void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, - T *gate_value, T *reset_output_value, - const T *prev_output_value, int frame_size, + T *gate_value, + T *reset_output_value, + const T *prev_output_value, + int frame_size, ActivationType active_gate, bool old_version = true, const T *reset_bias = nullptr) { @@ -152,8 +170,12 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, _mm256_loadu_ps((const float *)(reset_output_value + i)); } - op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, - &r_value_reset_output, active_gate, &r_reset_bias, + op_reset_output(&r_value_update_gate, + &r_value_reset_gate, + &r_prev_out, + &r_value_reset_output, + active_gate, + &r_reset_bias, old_version); _mm256_storeu_ps(reinterpret_cast(update_gate + i), @@ -167,9 +189,13 @@ 
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, if (rest > 0) { i = n - block; - op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last, - &r_prev_out_last, &r_value_reset_output, active_gate, - &r_reset_bias, old_version); + op_reset_output(&r_value_update_gate_last, + &r_value_reset_gate_last, + &r_prev_out_last, + &r_value_reset_output, + active_gate, + &r_reset_bias, + old_version); _mm256_storeu_ps(reinterpret_cast(update_gate + i), r_value_update_gate_last); @@ -183,8 +209,10 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, template void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, - T *gate_value, const T *prev_output_value, - T *output_value, int frame_size, + T *gate_value, + const T *prev_output_value, + T *output_value, + int frame_size, ActivationType active_node, bool origin_mode, bool old_version = true) { @@ -226,8 +254,12 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i)); } - op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node, origin_mode); + op_final_output(&r_value_update_gate, + &r_value_frame_state, + &r_prev_out, + &r_output, + active_node, + origin_mode); _mm256_storeu_ps(reinterpret_cast(frame_state + i), r_value_frame_state); @@ -236,8 +268,12 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, if (rest > 0) { i = n - block; - op_final_output(&r_value_update_gate_last, &r_value_frame_state_last, - &r_prev_out_last, &r_output, active_node, origin_mode); + op_final_output(&r_value_update_gate_last, + &r_value_frame_state_last, + &r_prev_out_last, + &r_output, + active_node, + origin_mode); _mm256_storeu_ps(reinterpret_cast(frame_state + i), r_value_frame_state_last); @@ -248,8 +284,10 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, } template -inline void forward_reset_outputV2(const platform::CPUDeviceContext &context, - GRUMetaValue value, int frame_size) { +inline void forward_reset_outputV2( + const paddle::platform::CPUDeviceContext &context, + phi::funcs::GRUMetaValue value, + int frame_size) { auto &place = *context.eigen_device(); auto value_reset_gate = typename EigenVector::Type(value.gate_value, Array1(frame_size)); @@ -259,17 +297,23 @@ inline void forward_reset_outputV2(const platform::CPUDeviceContext &context, value.reset_output_value, Array1(frame_size)); auto value_reset_bias = typename EigenVector::ConstType(value.reset_bias, Array1(frame_size)); - SigmoidFunctor()(place, value_reset_gate, value_reset_gate); - SigmoidFunctor()(place, value_update_gate, value_update_gate); + paddle::operators::SigmoidFunctor()( + place, value_reset_gate, value_reset_gate); + paddle::operators::SigmoidFunctor()( + place, value_update_gate, value_update_gate); value_reset_output.device(place) = (value_reset_output + value_reset_bias) * value_reset_gate; } template inline void forward_reset_output( - OpResetOutput op_reset_output, GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_gate, bool old_version = true, - const platform::CPUDeviceContext *context = nullptr) { + OpResetOutput op_reset_output, + phi::funcs::GRUMetaValue value, + int frame_size, + int batch_size, + ActivationType active_gate, + bool old_version = true, + const paddle::platform::CPUDeviceContext *context = nullptr) { for (int b = 0; b < batch_size; b++) { if (!old_version) { // use eigen @@ -277,15 +321,23 @@ inline void 
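The rest > 0 branches in the AVX kernels above handle frame sizes that are not a multiple of eight: the preloaded "last" vectors start at i = n - block, so the final unaligned store overlaps the previous block and covers exactly the remaining lanes without a scalar tail loop. A minimal sketch of the same pattern (assumes n >= 8; safe here because the source array is never modified):

#include <immintrin.h>

// Overlapping-tail trick used by the AVX GRU kernels, in isolation.
void ScaledCopy(const float* src, float* dst, int n, float factor) {
  const __m256 f = _mm256_set1_ps(factor);
  int i = 0;
  for (; i + 8 <= n; i += 8)
    _mm256_storeu_ps(dst + i, _mm256_mul_ps(_mm256_loadu_ps(src + i), f));
  if (n % 8 != 0) {  // redo one vector anchored at n - 8; overlap is harmless
    i = n - 8;
    _mm256_storeu_ps(dst + i, _mm256_mul_ps(_mm256_loadu_ps(src + i), f));
  }
}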
forward_reset_output( } else { if (OpResetOutput::avx && (frame_size > static_cast(8 - 1)) && (sizeof(T) == 4)) { - hl_avx_gru_forward_reset_output( - op_reset_output, value.gate_value, value.reset_output_value, - value.prev_out_value, frame_size, active_gate, old_version, - value.reset_bias); + hl_avx_gru_forward_reset_output(op_reset_output, + value.gate_value, + value.reset_output_value, + value.prev_out_value, + frame_size, + active_gate, + old_version, + value.reset_bias); } else { - hl_naive_gru_forward_reset_output( - op_reset_output, value.gate_value, value.reset_output_value, - value.prev_out_value, frame_size, active_gate, old_version, - value.reset_bias); + hl_naive_gru_forward_reset_output(op_reset_output, + value.gate_value, + value.reset_output_value, + value.prev_out_value, + frame_size, + active_gate, + old_version, + value.reset_bias); } } value.gate_value += frame_size * 3; @@ -297,8 +349,10 @@ inline void forward_reset_output( } template -inline void forward_final_outputV2(const platform::CPUDeviceContext &context, - GRUMetaValue value, int frame_size) { +inline void forward_final_outputV2( + const paddle::platform::CPUDeviceContext &context, + phi::funcs::GRUMetaValue value, + int frame_size) { auto &place = *context.eigen_device(); auto value_update_gate = typename EigenVector::Type( value.gate_value + frame_size, Array1(frame_size)); @@ -306,7 +360,8 @@ inline void forward_final_outputV2(const platform::CPUDeviceContext &context, value.gate_value + 2 * frame_size, Array1(frame_size)); auto value_output = typename EigenVector::Type(value.output_value, Array1(frame_size)); - TanhFunctor()(place, value_frame_state, value_frame_state); + paddle::operators::TanhFunctor()( + place, value_frame_state, value_frame_state); value_output.device(place) = (static_cast(1.0) - value_update_gate) * value_frame_state; if (value.prev_out_value) { @@ -319,10 +374,14 @@ inline void forward_final_outputV2(const platform::CPUDeviceContext &context, template inline void forward_final_output( - OpFinalOutput op_final_output, GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_node, bool origin_mode, + OpFinalOutput op_final_output, + phi::funcs::GRUMetaValue value, + int frame_size, + int batch_size, + ActivationType active_node, + bool origin_mode, bool old_version = true, - const platform::CPUDeviceContext *context = nullptr) { + const paddle::platform::CPUDeviceContext *context = nullptr) { for (int b = 0; b < batch_size; b++) { if (!old_version) { // eigen @@ -330,15 +389,23 @@ inline void forward_final_output( } else { if (OpFinalOutput::avx && (frame_size > static_cast(8 - 1)) && (sizeof(T) == 4)) { - hl_avx_gru_forward_final_output(op_final_output, value.gate_value, + hl_avx_gru_forward_final_output(op_final_output, + value.gate_value, value.prev_out_value, - value.output_value, frame_size, - active_node, origin_mode, old_version); + value.output_value, + frame_size, + active_node, + origin_mode, + old_version); } else { - hl_naive_gru_forward_final_output( - op_final_output, value.gate_value, value.prev_out_value, - value.output_value, frame_size, active_node, origin_mode, - old_version); + hl_naive_gru_forward_final_output(op_final_output, + value.gate_value, + value.prev_out_value, + value.output_value, + frame_size, + active_node, + origin_mode, + old_version); } } value.gate_value += frame_size * 3; @@ -350,9 +417,12 @@ inline void forward_final_output( } template -void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, - T 
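The per-batch loops above only ever hand the kernels a single row; between iterations the raw pointers advance by one row of the packed layout (value.gate_value += frame_size * 3, and the frame_size-wide buffers by frame_size). Roughly, the cursor they maintain looks like this (hypothetical struct, for illustration only):

// Row-major batch layout walked by the CPU GRU entry points: gate_value is
// [batch_size x 3*frame_size], the other buffers [batch_size x frame_size].
struct GruBatchCursor {
  float* gate_value;            // advances by 3 * frame_size per row
  float* reset_output_value;    // advances by frame_size per row
  float* output_value;          // advances by frame_size per row
  const float* prev_out_value;  // nullptr on the first time step

  void NextRow(int frame_size) {
    gate_value += frame_size * 3;
    reset_output_value += frame_size;
    output_value += frame_size;
    if (prev_out_value) prev_out_value += frame_size;
  }
};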
*gate_grad, const T *prev_out_value, - T *prev_out_grad, T *output_grad, +void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, + T *gate_value, + T *gate_grad, + const T *prev_out_value, + T *prev_out_grad, + T *output_grad, int frame_size, ActivationType active_node, bool origin_mode) { @@ -379,9 +449,15 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, r_prev_out_grad = prev_out_grad[i]; } - op_state_grad(&r_update_gate_value, &r_update_gate_grad, - &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, - &r_prev_out_grad, &r_out_grad, active_node, origin_mode); + op_state_grad(&r_update_gate_value, + &r_update_gate_grad, + &r_frame_state_value, + &r_frame_state_grad, + &r_prev_out_value, + &r_prev_out_grad, + &r_out_grad, + active_node, + origin_mode); update_gate_grad[i] = r_update_gate_grad; frame_state_grad[i] = r_frame_state_grad; @@ -392,9 +468,12 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, } template -void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, - T *gate_grad, const T *prev_out_value, - T *prev_out_grad, T *reset_output_grad, +void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, + T *gate_value, + T *gate_grad, + const T *prev_out_value, + T *prev_out_grad, + T *reset_output_grad, int frame_size, ActivationType active_gate) { T r_update_gate_value; @@ -424,9 +503,14 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, r_prev_out_grad = prev_out_grad[i]; } - op_reset_grad(&r_update_gate_value, &r_update_gate_grad, - &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value, - &r_prev_out_grad, &r_reset_output_grad, active_gate); + op_reset_grad(&r_update_gate_value, + &r_update_gate_grad, + &r_reset_gate_value, + &r_reset_gate_grad, + &r_prev_out_value, + &r_prev_out_grad, + &r_reset_output_grad, + active_gate); update_gate_grad[i] = r_update_gate_grad; reset_gate_grad[i] = r_reset_gate_grad; @@ -437,10 +521,14 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, } template -void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, - T *gate_grad, const T *prev_out_value, - T *prev_out_grad, T *output_grad, - int frame_size, ActivationType active_node, +void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, + T *gate_value, + T *gate_grad, + const T *prev_out_value, + T *prev_out_grad, + T *output_grad, + int frame_size, + ActivationType active_node, bool origin_mode) { #ifdef __AVX__ __m256 r_update_gate_value; @@ -468,9 +556,15 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; } - op_state_grad(&r_update_gate_value, &r_update_gate_grad, - &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, - &r_prev_out_grad, &r_out_grad, active_node, origin_mode); + op_state_grad(&r_update_gate_value, + &r_update_gate_grad, + &r_frame_state_value, + &r_frame_state_grad, + &r_prev_out_value, + &r_prev_out_grad, + &r_out_grad, + active_node, + origin_mode); update_gate_grad[i] = r_update_gate_grad; frame_state_grad[i] = r_frame_state_grad; @@ -482,9 +576,12 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, } template -void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, - T *gate_grad, const T *prev_out_value, - T *prev_out_grad, T *reset_output_grad, +void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, 
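The state-gradient kernels above implement the chain rule for the GRU output equation. For the default (non-origin) mode, where h_t = u * c + (1 - u) * h_prev, one element reduces to the following (u is the activated update gate, c the activated candidate, dh the output gradient; names are illustrative):

// Scalar sketch of one element of hl_naive_gru_backward_state_grad,
// assuming a tanh candidate activation so c already stores tanh output.
void GruStateGradElem(float u, float c, float h_prev, float dh,
                      float* du, float* dc, float* dh_prev) {
  *du = dh * (c - h_prev);        // d h_t / d u
  *dh_prev += dh * (1.0f - u);    // d h_t / d h_prev, accumulated
  *dc = dh * u * (1.0f - c * c);  // tanh backward through the candidate
}

With origin_mode the two coefficients swap: du = dh * (h_prev - c) and dh_prev += dh * u, matching the branch in the functor.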
+ T *gate_value, + T *gate_grad, + const T *prev_out_value, + T *prev_out_grad, + T *reset_output_grad, int frame_size, ActivationType active_gate) { #ifdef __AVX__ @@ -516,9 +613,14 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; } - op_reset_grad(&r_update_gate_value, &r_update_gate_grad, - &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value, - &r_prev_out_grad, &r_reset_output_grad, active_gate); + op_reset_grad(&r_update_gate_value, + &r_update_gate_grad, + &r_reset_gate_value, + &r_reset_gate_grad, + &r_prev_out_value, + &r_prev_out_grad, + &r_reset_output_grad, + active_gate); update_gate_grad[i] = r_update_gate_grad; reset_gate_grad[i] = r_reset_gate_grad; @@ -530,11 +632,16 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, } template -inline void hl_naive_gru_backward(OpGruGrad op_gru_grad, T *gate_value, - T *gate_grad, const T *prev_out_value, - T *prev_out_grad, T *reset_output_value, - T *reset_output_grad, T *output_grad, - int frame_size, ActivationType active_node, +inline void hl_naive_gru_backward(OpGruGrad op_gru_grad, + T *gate_value, + T *gate_grad, + const T *prev_out_value, + T *prev_out_grad, + T *reset_output_value, + T *reset_output_grad, + T *output_grad, + int frame_size, + ActivationType active_node, ActivationType active_gate) { T r_value_reset_gate; T r_grad_reset_gate; @@ -573,10 +680,18 @@ inline void hl_naive_gru_backward(OpGruGrad op_gru_grad, T *gate_value, r_grad_reset_output = reset_output_grad[i]; } - op_gru_grad(&r_value_reset_gate, &r_grad_reset_gate, &r_value_update_gate, - &r_grad_update_gate, &r_value_frame_state, &r_grad_frame_state, - &r_value_prev_out, &r_grad_prev_out, &r_grad_output, - &r_value_reset_output, &r_grad_reset_output, active_node, + op_gru_grad(&r_value_reset_gate, + &r_grad_reset_gate, + &r_value_update_gate, + &r_grad_update_gate, + &r_value_frame_state, + &r_grad_frame_state, + &r_value_prev_out, + &r_grad_prev_out, + &r_grad_output, + &r_value_reset_output, + &r_grad_reset_output, + active_node, active_gate); reset_gate_grad[i] = r_grad_reset_gate; @@ -592,11 +707,16 @@ inline void hl_naive_gru_backward(OpGruGrad op_gru_grad, T *gate_value, } template -inline void hl_avx_gru_backward(OpGruGrad op_gru_grad, T *gate_value, - T *gate_grad, const T *prev_out_value, - T *prev_out_grad, T *reset_output_value, - T *reset_output_grad, T *output_grad, - int frame_size, ActivationType active_node, +inline void hl_avx_gru_backward(OpGruGrad op_gru_grad, + T *gate_value, + T *gate_grad, + const T *prev_out_value, + T *prev_out_grad, + T *reset_output_value, + T *reset_output_grad, + T *output_grad, + int frame_size, + ActivationType active_node, ActivationType active_gate) { #ifdef __AVX__ __m256 r_value_reset_gate; @@ -639,10 +759,18 @@ inline void hl_avx_gru_backward(OpGruGrad op_gru_grad, T *gate_value, r_grad_reset_output = (reinterpret_cast<__m256 *>(reset_output_grad))[i]; } - op_gru_grad(&r_value_reset_gate, &r_grad_reset_gate, &r_value_update_gate, - &r_grad_update_gate, &r_value_frame_state, &r_grad_frame_state, - &r_value_prev_out, &r_grad_prev_out, &r_grad_output, - &r_value_reset_output, &r_grad_reset_output, active_node, + op_gru_grad(&r_value_reset_gate, + &r_grad_reset_gate, + &r_value_update_gate, + &r_grad_update_gate, + &r_value_frame_state, + &r_grad_frame_state, + &r_value_prev_out, + &r_grad_prev_out, + &r_grad_output, + &r_value_reset_output, + &r_grad_reset_output, + active_node, 
active_gate); reset_gate_grad[i] = r_grad_reset_gate; @@ -660,20 +788,33 @@ inline void hl_avx_gru_backward(OpGruGrad op_gru_grad, T *gate_value, template inline void backward_state_grad(OpStateGrad op_state_grad, - GRUMetaValue value, GRUMetaGrad grad, - int frame_size, int batch_size, - ActivationType active_node, bool origin_mode) { + phi::funcs::GRUMetaValue value, + phi::funcs::GRUMetaGrad grad, + int frame_size, + int batch_size, + ActivationType active_node, + bool origin_mode) { for (int b = 0; b < batch_size; b++) { if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { - hl_avx_gru_backward_state_grad(op_state_grad, value.gate_value, - grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.output_grad, - frame_size, active_node, origin_mode); + hl_avx_gru_backward_state_grad(op_state_grad, + value.gate_value, + grad.gate_grad, + value.prev_out_value, + grad.prev_out_grad, + grad.output_grad, + frame_size, + active_node, + origin_mode); } else { - hl_naive_gru_backward_state_grad(op_state_grad, value.gate_value, - grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.output_grad, - frame_size, active_node, origin_mode); + hl_naive_gru_backward_state_grad(op_state_grad, + value.gate_value, + grad.gate_grad, + value.prev_out_value, + grad.prev_out_grad, + grad.output_grad, + frame_size, + active_node, + origin_mode); } value.gate_value += frame_size * 3; @@ -691,18 +832,30 @@ inline void backward_state_grad(OpStateGrad op_state_grad, template inline void backward_reset_grad(OpResetGrad op_reset_grad, - GRUMetaValue value, GRUMetaGrad grad, - int frame_size, int batch_size, + phi::funcs::GRUMetaValue value, + phi::funcs::GRUMetaGrad grad, + int frame_size, + int batch_size, ActivationType active_gate) { for (int b = 0; b < batch_size; b++) { if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { - hl_avx_gru_backward_reset_grad( - op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate); + hl_avx_gru_backward_reset_grad(op_reset_grad, + value.gate_value, + grad.gate_grad, + value.prev_out_value, + grad.prev_out_grad, + grad.reset_output_grad, + frame_size, + active_gate); } else { - hl_naive_gru_backward_reset_grad( - op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate); + hl_naive_gru_backward_reset_grad(op_reset_grad, + value.gate_value, + grad.gate_grad, + value.prev_out_value, + grad.prev_out_grad, + grad.reset_output_grad, + frame_size, + active_gate); } value.gate_value += frame_size * 3; @@ -719,8 +872,9 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad, } template -inline void gru_backward(const platform::CPUDeviceContext &context, - GRUMetaValue value, GRUMetaGrad grad, +inline void gru_backward(const paddle::platform::CPUDeviceContext &context, + phi::funcs::GRUMetaValue value, + phi::funcs::GRUMetaGrad grad, int frame_size) { auto &place = *context.eigen_device(); @@ -747,13 +901,19 @@ inline void gru_backward(const platform::CPUDeviceContext &context, if (value.prev_out_value) { auto value_prev_out = typename EigenVector::ConstType( value.prev_out_value, Array1(frame_size)); - SigmoidGradFunctor()(place, 1 /*useless*/, value_update_gate, - (value_prev_out - value_frame_state) * grad_output, - grad_update_gate); + paddle::operators::SigmoidGradFunctor()( + place, + 1 /*useless*/, + value_update_gate, + (value_prev_out - value_frame_state) 
* grad_output, + grad_update_gate); } else { - SigmoidGradFunctor()( - place, 1 /*useless*/, value_update_gate, - static_cast(-1) * value_frame_state * grad_output, grad_update_gate); + paddle::operators::SigmoidGradFunctor()( + place, + 1 /*useless*/, + value_update_gate, + static_cast(-1) * value_frame_state * grad_output, + grad_update_gate); } if (grad.prev_out_grad) { auto grad_prev_out = @@ -761,11 +921,16 @@ inline void gru_backward(const platform::CPUDeviceContext &context, grad_prev_out.device(place) = grad_prev_out + grad_output * value_update_gate; } - TanhGradFunctor()(place, 1 /*useless*/, value_frame_state, - grad_output * (static_cast(1.0) - value_update_gate), - grad_frame_state); - SigmoidGradFunctor()( - place, 1 /*useless*/, value_reset_gate, + paddle::operators::TanhGradFunctor()( + place, + 1 /*useless*/, + value_frame_state, + grad_output * (static_cast(1.0) - value_update_gate), + grad_frame_state); + paddle::operators::SigmoidGradFunctor()( + place, + 1 /*useless*/, + value_reset_gate, value_reset_output / value_reset_gate * grad_frame_state, grad_reset_gate); if (value.prev_out_value && grad.prev_out_grad) { @@ -774,10 +939,13 @@ inline void gru_backward(const platform::CPUDeviceContext &context, } template -inline void cpu_gru_backward(const platform::CPUDeviceContext &context, - OpGruGrad op_gru_grad, GRUMetaValue value, - GRUMetaGrad grad, int frame_size, - int batch_size, ActivationType active_node, +inline void cpu_gru_backward(const paddle::platform::CPUDeviceContext &context, + OpGruGrad op_gru_grad, + phi::funcs::GRUMetaValue value, + phi::funcs::GRUMetaGrad grad, + int frame_size, + int batch_size, + ActivationType active_node, ActivationType active_gate) { for (int b = 0; b < batch_size; ++b) { // eigen @@ -801,6 +969,5 @@ inline void cpu_gru_backward(const platform::CPUDeviceContext &context, #endif // @} End Group for GRU CPU } // namespace detail -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h similarity index 74% rename from paddle/fluid/operators/math/detail/gru_gpu_kernel.h rename to paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h index 75d4809a462..6657417beac 100644 --- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h @@ -14,14 +14,13 @@ limitations under the License. 
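The Eigen-based gru_backward just above handles the new-API formulation h_t = u * h_prev + (1 - u) * c, which is why the update-gate gradient multiplies grad_output by (prev_out - frame_state), or by -frame_state when there is no previous output. Per element, roughly:

// Per-element view of the update-gate branch in gru_backward (sketch).
// u and c are post-activation values, dh the output gradient; the sigmoid
// factor appears because the stored gate value is already sigma(x).
float GruV2UpdateGateGrad(float u, float c, float h_prev, bool has_prev,
                          float dh) {
  float d_out_wrt_u = has_prev ? (h_prev - c) : -c;  // h_prev treated as 0
  return dh * d_out_wrt_u * u * (1.0f - u);          // sigmoid backward
}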
*/ #pragma once #include -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/gru_compute.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { namespace detail { /* @@ -30,9 +29,11 @@ namespace detail { */ template __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, - T *gate_value, T *reset_output_value, + T *gate_value, + T *reset_output_value, const T *prev_output_value, - int frame_size, int batch_size, + int frame_size, + int batch_size, ActivationType active_gate) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; @@ -55,8 +56,11 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, r_prev_out = prev_output_value[frame_idx]; } - op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, - &r_value_reset_output, active_gate); + op_reset_output(&r_value_update_gate, + &r_value_reset_gate, + &r_prev_out, + &r_value_reset_output, + active_gate); gate_value[frame_idx + frame_size * 0] = r_value_update_gate; gate_value[frame_idx + frame_size * 1] = r_value_reset_gate; @@ -68,10 +72,14 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, * grid(frame_blocks, batch_blocks) */ template -__global__ void KeGruForwardFinalOutput( - OpFinalOutput op_final_output, T *gate_value, const T *prev_output_value, - T *output_value, int frame_size, int batch_size, ActivationType active_node, - bool origin_mode) { +__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, + T *gate_value, + const T *prev_output_value, + T *output_value, + int frame_size, + int batch_size, + ActivationType active_node, + bool origin_mode) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -92,8 +100,12 @@ __global__ void KeGruForwardFinalOutput( r_prev_out = prev_output_value[frame_idx]; } - op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node, origin_mode); + op_final_output(&r_value_update_gate, + &r_value_frame_state, + &r_prev_out, + &r_output, + active_node, + origin_mode); gate_value[frame_idx + frame_size * 2] = r_value_frame_state; output_value[frame_idx] = r_output; @@ -106,7 +118,8 @@ __global__ void KeGruForwardFinalOutput( template __global__ void KeFastCollectiveGruGate(T *gate_value, const T *prev_output_value, - const T *gate_weight, T *reset_output, + const T *gate_weight, + T *reset_output, int frame_size, ActivationType active_node) { T xt_0 = 0.0f; @@ -164,9 +177,12 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, */ template __global__ void KeFastCollectiveGruOut(const T *gate_weight, - const T *prev_out_value, T *output_value, - T *gate_value, T *reset_value, - int frame_size, ActivationType act_node, + const T *prev_out_value, + T *output_value, + T *gate_value, + T *reset_value, + int frame_size, + ActivationType act_node, bool origin_mode) { int COL = blockIdx.x * blockDim.x + threadIdx.x; @@ -221,10 +237,14 @@ __global__ void KeFastCollectiveGruOut(const T *gate_weight, * grid(frame_blocks, batch_blocks) */ template -__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, - T *gate_grad, const 
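All of the KeGru* kernels above share one indexing scheme: threads tile the frame dimension along x, and for batched input each y lane owns one batch row, with the raw pointers bumped to that row before the elementwise math. A standalone sketch (hypothetical kernel; the y mapping for the batched launch is an assumption based on the truncated hunks):

// Illustrative CUDA kernel showing the frame-by-batch tiling; is_batch
// mirrors the template parameter the real kernels carry.
template <bool is_batch>
__global__ void PerFrameSketch(float* data, int frame_size, int batch_size) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    data += batch_idx * frame_size;  // move to this row's frame vector
  }
  data[frame_idx] += 1.0f;  // stand-in for the per-element gate update
}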
T *prev_out_value, - T *prev_out_grad, T *output_grad, - int frame_size, int batch_size, +__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, + T *gate_value, + T *gate_grad, + const T *prev_out_value, + T *prev_out_grad, + T *output_grad, + int frame_size, + int batch_size, ActivationType active_node, bool origin_mode) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -254,9 +274,15 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, r_prev_out_grad = prev_out_grad[frame_idx]; } - op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value, - &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad, - &r_out_grad, active_node, origin_mode); + op_state_grad(&r_update_gate_value, + &r_update_gate_grad, + &r_frame_state_value, + &r_frame_state_grad, + &r_prev_out_value, + &r_prev_out_grad, + &r_out_grad, + active_node, + origin_mode); gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad; @@ -270,10 +296,14 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, * grid(frame_blocks, batch_blocks) */ template -__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, - T *gate_grad, const T *prev_out_value, - T *prev_out_grad, T *reset_output_grad, - int frame_size, int batch_size, +__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, + T *gate_value, + T *gate_grad, + const T *prev_out_value, + T *prev_out_grad, + T *reset_output_grad, + int frame_size, + int batch_size, ActivationType active_gate) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; @@ -302,9 +332,14 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, r_reset_output_grad = reset_output_grad[frame_idx]; } - op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value, - &r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad, - &r_reset_output_grad, active_gate); + op_reset_grad(&r_update_gate_value, + &r_update_gate_grad, + &r_reset_gate_value, + &r_reset_gate_grad, + &r_prev_out_value, + &r_prev_out_grad, + &r_reset_output_grad, + active_gate); gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad; @@ -313,6 +348,5 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, } } } // namespace detail -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/detail/gru_kernel.h b/paddle/phi/kernels/funcs/detail/gru_kernel.h similarity index 64% rename from paddle/fluid/operators/math/detail/gru_kernel.h rename to paddle/phi/kernels/funcs/detail/gru_kernel.h index 082c2a180da..db53fc4576d 100644 --- a/paddle/fluid/operators/math/detail/gru_kernel.h +++ b/paddle/phi/kernels/funcs/detail/gru_kernel.h @@ -14,13 +14,12 @@ limitations under the License. 
*/ #pragma once #include -#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" // TODO(guosheng): refine code style in gru_kernel -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { namespace detail { namespace forward { @@ -28,8 +27,10 @@ namespace forward { template class gru_resetOutput { public: - HOSTDEVICE void operator()(T *value_update_gate, T *value_reset_gate, - T *prev_out, T *value_reset_output, + HOSTDEVICE void operator()(T *value_update_gate, + T *value_reset_gate, + T *prev_out, + T *value_reset_output, ActivationType act_gate, T *value_reset_bias = nullptr, bool old_version = true) { @@ -48,7 +49,8 @@ class gru_resetOutput { #else static const bool avx = true; HOSTDEVICE void operator()(__m256 *value_update_gate, - __m256 *value_reset_gate, __m256 *prev_out, + __m256 *value_reset_gate, + __m256 *prev_out, __m256 *value_reset_output, ActivationType act_gate, __m256 *value_reset_bias = nullptr, @@ -71,9 +73,12 @@ class gru_resetOutput { template class gru_finalOutput { public: - HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state, - T *prev_out, T *value_output, - ActivationType act_input, bool origin_mode) { + HOSTDEVICE void operator()(T *value_update_gate, + T *value_frame_state, + T *prev_out, + T *value_output, + ActivationType act_input, + bool origin_mode) { *value_frame_state = activation(*value_frame_state, act_input); if (origin_mode) { *value_output = ((*value_update_gate) * (*prev_out)) + @@ -90,8 +95,10 @@ class gru_finalOutput { #else static const bool avx = true; HOSTDEVICE void operator()(__m256 *value_update_gate, - __m256 *value_frame_state, __m256 *prev_out, - __m256 *value_output, ActivationType act_input, + __m256 *value_frame_state, + __m256 *prev_out, + __m256 *value_output, + ActivationType act_input, bool origin_mode) { *value_frame_state = activation(*value_frame_state, act_input); if (origin_mode) { @@ -116,10 +123,14 @@ namespace backward { template class gru_stateGrad { public: - HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate, - T *value_frame_state, T *grad_frame_state, - T *value_prev_out, T *grad_prev_out, - T *grad_output, ActivationType act_input, + HOSTDEVICE void operator()(T *value_update_gate, + T *grad_update_gate, + T *value_frame_state, + T *grad_frame_state, + T *value_prev_out, + T *grad_prev_out, + T *grad_output, + ActivationType act_input, bool origin_mode) { if (origin_mode) { *grad_update_gate = @@ -127,14 +138,15 @@ class gru_stateGrad { *grad_prev_out += (*grad_output * (*value_update_gate)); *grad_frame_state = activation( *grad_output * (static_cast(1.0) - (*value_update_gate)), - *value_frame_state, act_input); + *value_frame_state, + act_input); } else { *grad_update_gate = (*grad_output) * ((*value_frame_state) - (*value_prev_out)); *grad_prev_out += (*grad_output * (static_cast(1.0) - *value_update_gate)); - *grad_frame_state = activation(*grad_output * (*value_update_gate), - *value_frame_state, act_input); + *grad_frame_state = activation( + *grad_output * (*value_update_gate), *value_frame_state, act_input); } } #if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU state grad @@ -145,28 +157,35 @@ class gru_stateGrad { HOSTDEVICE void operator()(__m256 *value_update_gate, __m256 *grad_update_gate, __m256 *value_frame_state, - __m256 *grad_frame_state, __m256 *value_prev_out, - __m256 *grad_prev_out, 
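gru_finalOutput above encodes both GRU output conventions behind the origin_mode flag, and gru_stateGrad below mirrors them with the coefficients swapped. In scalar form (u = update gate, c = activated candidate, h_prev = previous output; sketch only):

// The two output conventions selected by origin_mode.
float GruFinalOutput(float u, float c, float h_prev, bool origin_mode) {
  return origin_mode ? u * h_prev + (1.0f - u) * c
                     : u * c + (1.0f - u) * h_prev;
}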
__m256 *grad_output, - ActivationType act_input, bool origin_mode) { + __m256 *grad_frame_state, + __m256 *value_prev_out, + __m256 *grad_prev_out, + __m256 *grad_output, + ActivationType act_input, + bool origin_mode) { if (origin_mode) { *grad_update_gate = _mm256_mul_ps( *grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state)); *grad_prev_out = _mm256_add_ps( *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate)); *grad_frame_state = activation( - _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f), - *value_update_gate)), - *value_frame_state, act_input); + _mm256_mul_ps( + *grad_output, + _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate)), + *value_frame_state, + act_input); } else { *grad_update_gate = _mm256_mul_ps( *grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out)); *grad_prev_out = _mm256_add_ps( *grad_prev_out, - _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f), - *value_update_gate))); + _mm256_mul_ps( + *grad_output, + _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate))); *grad_frame_state = activation(_mm256_mul_ps(*grad_output, *value_update_gate), - *value_frame_state, act_input); + *value_frame_state, + act_input); } } #endif @@ -176,10 +195,14 @@ class gru_stateGrad { template class gru_resetGrad { public: - HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate, - T *value_reset_gate, T *grad_reset_gate, - T *value_prev_out, T *grad_prev_out, - T *grad_reset_output, ActivationType act_gate) { + HOSTDEVICE void operator()(T *value_update_gate, + T *grad_update_gate, + T *value_reset_gate, + T *grad_reset_gate, + T *value_prev_out, + T *grad_prev_out, + T *grad_reset_output, + ActivationType act_gate) { *grad_reset_gate = (*grad_reset_output * (*value_prev_out)); *grad_prev_out += (*grad_reset_output * (*value_reset_gate)); *grad_update_gate = @@ -193,9 +216,12 @@ class gru_resetGrad { #else static const bool avx = true; HOSTDEVICE void operator()(__m256 *value_update_gate, - __m256 *grad_update_gate, __m256 *value_reset_gate, - __m256 *grad_reset_gate, __m256 *value_prev_out, - __m256 *grad_prev_out, __m256 *grad_reset_output, + __m256 *grad_update_gate, + __m256 *value_reset_gate, + __m256 *grad_reset_gate, + __m256 *value_prev_out, + __m256 *grad_prev_out, + __m256 *grad_reset_output, ActivationType act_gate) { *grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out); *grad_prev_out = _mm256_add_ps( @@ -211,23 +237,31 @@ class gru_resetGrad { template class gru { public: - HOSTDEVICE void operator()(T *value_reset_gate, T *grad_reset_gate, - T *value_update_gate, T *grad_update_gate, - T *value_frame_state, T *grad_frame_state, - T *value_prev_out, T *grad_prev_out, - T *grad_output, T *value_reset_output, - T *grad_reset_output, ActivationType act_node, + HOSTDEVICE void operator()(T *value_reset_gate, + T *grad_reset_gate, + T *value_update_gate, + T *grad_update_gate, + T *value_frame_state, + T *grad_frame_state, + T *value_prev_out, + T *grad_prev_out, + T *grad_output, + T *value_reset_output, + T *grad_reset_output, + ActivationType act_node, ActivationType act_gate) { *grad_update_gate = activation((*grad_output) * ((*value_prev_out) - (*value_frame_state)), - (*value_update_gate), act_gate); + (*value_update_gate), + act_gate); *grad_prev_out += (*grad_output * (*value_update_gate)); *grad_frame_state = activation(*grad_output * (static_cast(1.0) - (*value_update_gate)), - *value_frame_state, act_node); + *value_frame_state, + act_node); T reset_output = 
(*value_reset_output) / (*value_reset_gate); - *grad_reset_gate = activation(reset_output * (*grad_frame_state), - *value_reset_gate, act_gate); + *grad_reset_gate = activation( + reset_output * (*grad_frame_state), *value_reset_gate, act_gate); *grad_reset_output = (*value_reset_gate) * (*grad_frame_state); } #if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU CPU @@ -235,29 +269,36 @@ class gru { static const bool avx = false; #else static const bool avx = true; - HOSTDEVICE void operator()(__m256 *value_reset_gate, __m256 *grad_reset_gate, + HOSTDEVICE void operator()(__m256 *value_reset_gate, + __m256 *grad_reset_gate, __m256 *value_update_gate, __m256 *grad_update_gate, __m256 *value_frame_state, - __m256 *grad_frame_state, __m256 *value_prev_out, - __m256 *grad_prev_out, __m256 *grad_output, + __m256 *grad_frame_state, + __m256 *value_prev_out, + __m256 *grad_prev_out, + __m256 *grad_output, __m256 *value_reset_output, - __m256 *grad_reset_output, ActivationType act_node, + __m256 *grad_reset_output, + ActivationType act_node, ActivationType act_gate) { *grad_update_gate = activation( _mm256_mul_ps(*grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state)), - *value_update_gate, act_gate); + *value_update_gate, + act_gate); *grad_prev_out = _mm256_add_ps( *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate)); *grad_frame_state = activation( _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate)), - *value_frame_state, act_node); + *value_frame_state, + act_node); __m256 reset_output = _mm256_div_ps(*value_reset_output, *value_reset_gate); *grad_reset_gate = activation(_mm256_mul_ps(reset_output, *grad_frame_state), - *value_reset_gate, act_gate); + *value_reset_gate, + act_gate); *grad_reset_output = _mm256_mul_ps(*value_reset_gate, *grad_frame_state); } #endif @@ -267,6 +308,5 @@ class gru { } // namespace backward } // namespace detail -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h b/paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h similarity index 65% rename from paddle/fluid/operators/math/detail/lstm_cpu_kernel.h rename to paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h index 169c5488bb5..10dbf27d348 100644 --- a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h @@ -16,8 +16,8 @@ limitations under the License. 
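Before the lstm_cpu_kernel.h hunks, a note on the cell the Op functor computes per element: it is the peephole LSTM, where the input and forget gates peek at the previous cell state through checkI/checkF, the output gate peeks at the freshly computed state through checkO, and the cell value may be clipped. A scalar sketch assuming sigmoid gates and tanh node/state activations (names are illustrative, not the functor's parameters):

#include <algorithm>
#include <cmath>

static float Sig(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// One element of the peephole LSTM forward step; inputs are preactivations.
void LstmForwardElem(float in, float ig, float fg, float og, float c_prev,
                     float wI, float wF, float wO, float cell_clip,
                     float* c_out, float* h_out) {
  float i = Sig(ig + c_prev * wI);  // input gate, peephole on c_prev
  float f = Sig(fg + c_prev * wF);  // forget gate, peephole on c_prev
  float g = std::tanh(in);          // candidate
  float c = g * i + c_prev * f;
  if (cell_clip > 0.0f) c = std::min(std::max(c, -cell_clip), cell_clip);
  float o = Sig(og + c * wO);       // output gate peeks at the new cell
  *c_out = c;
  *h_out = o * std::tanh(c);
}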
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/lstm_compute.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/lstm_compute.h" #if defined(_WIN32) #if defined(__AVX2__) || defined(__AVX__) @@ -25,21 +25,23 @@ inline __m256 operator+=(__m256 a, __m256 b) { return _mm256_add_ps(a, b); } #endif #endif -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { namespace detail { using Array1 = Eigen::DSizes; -template -using EigenVector = framework::EigenVector; +using EigenVector = paddle::framework::EigenVector; #if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group LSTM CPU template -void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frame_size, T cell_clip, +void naive_lstm_forward_one_sequence(Op op, + phi::funcs::LstmMetaValue value, + int frame_size, + T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state, @@ -79,9 +81,21 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, r_prev_state = value.prev_state_value[i]; } - op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, - &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - &cell_clip, active_node, active_gate, active_state); + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_prev_state, + &r_state, + &r_state_atv, + &r_out, + &r_checkI, + &r_checkF, + &r_checkO, + &cell_clip, + active_node, + active_gate, + active_state); value_in[i] = r_value_in; value_ig[i] = r_value_ig; @@ -94,9 +108,12 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, } template -void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frame_size, - T cell_clip, ActivationType active_node, +void naive_lstm_backward_one_sequence(Op op, + phi::funcs::LstmMetaValue value, + phi::funcs::LstmMetaGrad grad, + int frame_size, + T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state, bool old_api_version) { @@ -157,11 +174,30 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, r_prev_state = value.prev_state_value[i]; } - op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, - &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, - &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, - &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, - &cell_clip, active_node, active_gate, active_state); + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_grad_in, + &r_grad_ig, + &r_grad_fg, + &r_grad_og, + &r_prev_state, + &r_prev_state_grad, + &r_state, + &r_state_grad, + &r_state_atv, + &r_output_grad, + &r_checkI, + &r_checkF, + &r_checkO, + &r_checkIGrad, + &r_checkFGrad, + &r_checkOGrad, + &cell_clip, + active_node, + active_gate, + active_state); grad_in[i] = r_grad_in; grad_ig[i] = r_grad_ig; @@ -179,8 +215,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, } template -void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frame_size, T cell_clip, +void avx_lstm_forward_one_sequence(Op op, + phi::funcs::LstmMetaValue value, + int frame_size, + T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state, @@ -226,9 +264,21 @@ void 
avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, (reinterpret_cast<__m256 const *>(value.prev_state_value))[i]; } - op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, - &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - &cell_clip, active_node, active_gate, active_state); + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_prev_state, + &r_state, + &r_state_atv, + &r_out, + &r_checkI, + &r_checkF, + &r_checkO, + &cell_clip, + active_node, + active_gate, + active_state); value_in[i] = r_value_in; value_ig[i] = r_value_ig; @@ -242,9 +292,12 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, } template -void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frame_size, - T cell_clip, ActivationType active_node, +void avx_lstm_backward_one_sequence(Op op, + phi::funcs::LstmMetaValue value, + phi::funcs::LstmMetaGrad grad, + int frame_size, + T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state, bool old_api_version) { @@ -311,11 +364,30 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, (reinterpret_cast<__m256 const *>(value.prev_state_value))[i]; } - op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, - &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, - &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, - &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, - &cell_clip, active_node, active_gate, active_state); + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_grad_in, + &r_grad_ig, + &r_grad_fg, + &r_grad_og, + &r_prev_state, + &r_prev_state_grad, + &r_state, + &r_state_grad, + &r_state_atv, + &r_output_grad, + &r_checkI, + &r_checkF, + &r_checkO, + &r_checkIGrad, + &r_checkFGrad, + &r_checkOGrad, + &cell_clip, + active_node, + active_gate, + active_state); grad_in[i] = r_grad_in; grad_ig[i] = r_grad_ig; @@ -338,8 +410,10 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, } template -void eigen_lstm_forward_one_sequence(const platform::CPUDeviceContext &context, - LstmMetaValue value, int frame_size) { +void eigen_lstm_forward_one_sequence( + const paddle::platform::CPUDeviceContext &context, + phi::funcs::LstmMetaValue value, + int frame_size) { auto eigen_value_ig = typename EigenVector::Type(value.gate_value, Array1(frame_size)); auto eigen_value_fg = typename EigenVector::Type( @@ -356,10 +430,10 @@ void eigen_lstm_forward_one_sequence(const platform::CPUDeviceContext &context, typename EigenVector::Type(value.output_value, Array1(frame_size)); auto &place = *context.eigen_device(); - TanhFunctor()(place, eigen_value_in, eigen_value_in); - SigmoidFunctor()(place, eigen_value_ig, eigen_value_ig); - SigmoidFunctor()(place, eigen_value_fg, eigen_value_fg); - SigmoidFunctor()(place, eigen_value_og, eigen_value_og); + paddle::operators::TanhFunctor()(place, eigen_value_in, eigen_value_in); + paddle::operators::SigmoidFunctor()(place, eigen_value_ig, eigen_value_ig); + paddle::operators::SigmoidFunctor()(place, eigen_value_fg, eigen_value_fg); + paddle::operators::SigmoidFunctor()(place, eigen_value_og, eigen_value_og); eigen_state.device(place) = eigen_value_in * eigen_value_ig; if (value.prev_state_value) { @@ -368,14 +442,16 @@ void eigen_lstm_forward_one_sequence(const platform::CPUDeviceContext &context, eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg; } - TanhFunctor()(place, eigen_state, 
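eigen_lstm_forward_one_sequence above is the new-API LSTM without peepholes or cell clipping; the whole-array Eigen expressions it evaluates reduce, per element, to:

#include <cmath>

// Per-element view of the Eigen LSTM forward (sketch only).
void LstmV2ForwardElem(float in, float i_pre, float f_pre, float o_pre,
                       float c_prev, float* c_out, float* h_out) {
  float i = 1.0f / (1.0f + std::exp(-i_pre));
  float f = 1.0f / (1.0f + std::exp(-f_pre));
  float o = 1.0f / (1.0f + std::exp(-o_pre));
  float c = std::tanh(in) * i + c_prev * f;  // no peepholes, no clipping
  *c_out = c;
  *h_out = o * std::tanh(c);
}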
eigen_state_act); + paddle::operators::TanhFunctor()(place, eigen_state, eigen_state_act); eigen_output.device(place) = eigen_value_og * eigen_state_act; } template -void eigen_lstm_backward_one_sequence(const platform::CPUDeviceContext &context, - LstmMetaValue value, - LstmMetaGrad grad, int frame_size) { +void eigen_lstm_backward_one_sequence( + const paddle::platform::CPUDeviceContext &context, + phi::funcs::LstmMetaValue value, + phi::funcs::LstmMetaGrad grad, + int frame_size) { auto eigen_value_ig = typename EigenVector::Type(value.gate_value, Array1(frame_size)); auto eigen_value_fg = typename EigenVector::Type( @@ -401,23 +477,38 @@ void eigen_lstm_backward_one_sequence(const platform::CPUDeviceContext &context, typename EigenVector::Type(grad.state_grad, Array1(frame_size)); auto &place = *context.eigen_device(); - SigmoidGradFunctor()(place, 1 /*useless*/, eigen_value_og, - eigen_grad_output * eigen_state_act, eigen_grad_og); + paddle::operators::SigmoidGradFunctor()( + place, + 1 /*useless*/, + eigen_value_og, + eigen_grad_output * eigen_state_act, + eigen_grad_og); eigen_grad_state.device(place) = eigen_grad_state + eigen_grad_output * eigen_value_og * (static_cast(1) - eigen_state_act * eigen_state_act); - TanhGradFunctor()(place, 1, eigen_value_in, - eigen_grad_state * eigen_value_ig, eigen_grad_in); - SigmoidGradFunctor()(place, 1, eigen_value_ig, - eigen_grad_state * eigen_value_in, eigen_grad_ig); + paddle::operators::TanhGradFunctor()(place, + 1, + eigen_value_in, + eigen_grad_state * eigen_value_ig, + eigen_grad_in); + paddle::operators::SigmoidGradFunctor()(place, + 1, + eigen_value_ig, + eigen_grad_state * eigen_value_in, + eigen_grad_ig); if (value.prev_state_value) { auto eigen_prev_state = typename EigenVector::ConstType( value.prev_state_value, Array1(frame_size)); - SigmoidGradFunctor()(place, 1, eigen_value_fg, - eigen_grad_state * eigen_prev_state, eigen_grad_fg); + paddle::operators::SigmoidGradFunctor()( + place, + 1, + eigen_value_fg, + eigen_grad_state * eigen_prev_state, + eigen_grad_fg); } else { - SigmoidGradFunctor()(place, 1, eigen_value_fg, 0, eigen_grad_fg); + paddle::operators::SigmoidGradFunctor()( + place, 1, eigen_value_fg, 0, eigen_grad_fg); } if (grad.prev_state_grad) { auto eigen_grad_pre_state = @@ -427,42 +518,74 @@ void eigen_lstm_backward_one_sequence(const platform::CPUDeviceContext &context, } template -void cpu_lstm_forward(const platform::CPUDeviceContext &context, Op op, - LstmMetaValue value, int frame_size, T cell_clip, - ActivationType active_node, ActivationType active_gate, - ActivationType active_state, bool old_api_version) { +void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context, + Op op, + phi::funcs::LstmMetaValue value, + int frame_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state, + bool old_api_version) { if (!old_api_version) { eigen_lstm_forward_one_sequence(context, value, frame_size); } else { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { - avx_lstm_forward_one_sequence(op, value, frame_size, cell_clip, - active_node, active_gate, active_state, + avx_lstm_forward_one_sequence(op, + value, + frame_size, + cell_clip, + active_node, + active_gate, + active_state, old_api_version); } else { - naive_lstm_forward_one_sequence(op, value, frame_size, cell_clip, - active_node, active_gate, active_state, + naive_lstm_forward_one_sequence(op, + value, + frame_size, + cell_clip, + active_node, + active_gate, + active_state, 
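eigen_lstm_backward_one_sequence runs the reverse chain of that forward. With i, f, o the gate outputs, g = tanh(in) as stored in gate_value, and dh, dc the incoming gradients, the rules it evaluates are roughly (scalar rendering; illustrative names):

#include <cmath>

void LstmV2BackwardElem(float i, float f, float o, float g, float c,
                        float c_prev, float dh, float* dc,
                        float* d_in, float* d_i, float* d_f, float* d_o,
                        float* dc_prev) {
  float tc = std::tanh(c);
  *d_o = dh * tc * o * (1.0f - o);      // sigmoid backward on output gate
  *dc += dh * o * (1.0f - tc * tc);     // through tanh(c)
  *d_in = (*dc) * i * (1.0f - g * g);   // tanh backward on the candidate
  *d_i = (*dc) * g * i * (1.0f - i);
  *d_f = (*dc) * c_prev * f * (1.0f - f);
  *dc_prev = (*dc) * f;
}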
old_api_version); } } } template -void cpu_lstm_backward(const platform::CPUDeviceContext &context, Op op, - LstmMetaValue value, LstmMetaGrad grad, - int frame_size, T cell_clip, ActivationType active_node, - ActivationType active_gate, ActivationType active_state, +void cpu_lstm_backward(const paddle::platform::CPUDeviceContext &context, + Op op, + phi::funcs::LstmMetaValue value, + phi::funcs::LstmMetaGrad grad, + int frame_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state, bool old_api_version) { if (!old_api_version) { eigen_lstm_backward_one_sequence(context, value, grad, frame_size); } else { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { - avx_lstm_backward_one_sequence(op, value, grad, frame_size, cell_clip, - active_node, active_gate, active_state, + avx_lstm_backward_one_sequence(op, + value, + grad, + frame_size, + cell_clip, + active_node, + active_gate, + active_state, old_api_version); } else { - naive_lstm_backward_one_sequence(op, value, grad, frame_size, - cell_clip, active_node, active_gate, - active_state, old_api_version); + naive_lstm_backward_one_sequence(op, + value, + grad, + frame_size, + cell_clip, + active_node, + active_gate, + active_state, + old_api_version); } } } @@ -470,6 +593,5 @@ void cpu_lstm_backward(const platform::CPUDeviceContext &context, Op op, #endif // @{ End Group LSTM CPU } // namespace detail -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h b/paddle/phi/kernels/funcs/detail/lstm_gpu_kernel.h similarity index 68% rename from paddle/fluid/operators/math/detail/lstm_gpu_kernel.h rename to paddle/phi/kernels/funcs/detail/lstm_gpu_kernel.h index 851a62dbe9a..6d4c430d9e6 100644 --- a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/phi/kernels/funcs/detail/lstm_gpu_kernel.h @@ -15,14 +15,13 @@ limitations under the License. 
*/ #pragma once #include -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" +#include "paddle/phi/kernels/funcs/lstm_compute.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { namespace detail { /* @@ -30,8 +29,11 @@ namespace detail { * grid(frame_blocks, batch_blocks) */ template -__global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, - int batch_size, T cell_clip, +__global__ void KeLstmForward(Op op, + phi::funcs::LstmMetaValue value, + int frame_size, + int batch_size, + T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -71,9 +73,21 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, r_prev_state = value.prev_state_value[frame_idx]; } - op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, - &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - &cell_clip, active_node, active_gate, active_state); + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_prev_state, + &r_state, + &r_state_atv, + &r_out, + &r_checkI, + &r_checkF, + &r_checkO, + &cell_clip, + active_node, + active_gate, + active_state); value.gate_value[frame_idx] = r_value_in; value.gate_value[frame_idx + frame_size] = r_value_ig; @@ -90,9 +104,12 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, * grid(frame_blocks, batch_blocks) */ template -__global__ void KeLstmBackward(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frame_size, - int batch_size, T cell_clip, +__global__ void KeLstmBackward(Op op, + phi::funcs::LstmMetaValue value, + phi::funcs::LstmMetaGrad grad, + int frame_size, + int batch_size, + T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -147,11 +164,30 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, r_prev_state = value.prev_state_value[frame_idx]; } - op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig, - &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, - &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, - &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, &cell_clip, - active_node, active_gate, active_state); + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_grad_in, + &r_grad_ig, + &r_grad_fg, + &r_grad_og, + &r_prev_state, + &r_prev_state_grad, + &r_state, + &r_state_grad, + &r_state_atv, + &r_output_grad, + &r_checkI, + &r_checkF, + &r_checkO, + &r_checkIGrad, + &r_checkFGrad, + &r_checkOGrad, + &cell_clip, + active_node, + active_gate, + active_state); grad.gate_grad[frame_idx] = r_grad_in; grad.gate_grad[frame_idx + frame_size] = r_grad_ig; @@ -185,10 +221,15 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, } template -void gpu_lstm_forward(const platform::DeviceContext& context, Op op, - LstmMetaValue value, int frame_size, int batch_size, - T cell_clip, ActivationType active_node, - ActivationType active_gate, ActivationType active_state) { +void gpu_lstm_forward(const paddle::platform::DeviceContext& context, + Op op, + phi::funcs::LstmMetaValue value, + int frame_size, + int batch_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + 
ActivationType active_state) { dim3 threads; dim3 grid; if (batch_size == 1) { @@ -203,25 +244,45 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, } auto stream = - reinterpret_cast(context).stream(); + reinterpret_cast(context) + .stream(); if (batch_size == 1) { - KeLstmForward<<>>( - op, value, frame_size, batch_size, cell_clip, active_node, active_gate, + op, + value, + frame_size, + batch_size, + cell_clip, + active_node, + active_gate, active_state); } else { - KeLstmForward<<>>( - op, value, frame_size, batch_size, cell_clip, active_node, active_gate, + op, + value, + frame_size, + batch_size, + cell_clip, + active_node, + active_gate, active_state); } } template -void gpu_lstm_backward(const platform::DeviceContext& context, Op op, - LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, T cell_clip, - ActivationType active_node, ActivationType active_gate, +void gpu_lstm_backward(const paddle::platform::DeviceContext& context, + Op op, + phi::funcs::LstmMetaValue value, + phi::funcs::LstmMetaGrad grad, + int frame_size, + int batch_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, ActivationType active_state) { dim3 threads; dim3 grid; @@ -237,21 +298,37 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, } auto stream = - reinterpret_cast(context).stream(); + reinterpret_cast(context) + .stream(); if (batch_size == 1) { - KeLstmBackward<<>>( - op, value, grad, frame_size, batch_size, cell_clip, active_node, - active_gate, active_state); + op, + value, + grad, + frame_size, + batch_size, + cell_clip, + active_node, + active_gate, + active_state); } else { - KeLstmBackward<<>>( - op, value, grad, frame_size, batch_size, cell_clip, active_node, - active_gate, active_state); + op, + value, + grad, + frame_size, + batch_size, + cell_clip, + active_node, + active_gate, + active_state); } } } // namespace detail -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/detail/lstm_kernel.h b/paddle/phi/kernels/funcs/detail/lstm_kernel.h similarity index 59% rename from paddle/fluid/operators/math/detail/lstm_kernel.h rename to paddle/phi/kernels/funcs/detail/lstm_kernel.h index 2d4e7dd59fb..8b429264125 100644 --- a/paddle/fluid/operators/math/detail/lstm_kernel.h +++ b/paddle/phi/kernels/funcs/detail/lstm_kernel.h @@ -14,12 +14,11 @@ limitations under the License. 
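The launch-shape logic in gpu_lstm_forward/gpu_lstm_backward specializes the single-sequence case: one sequence gets a 1-D frame-parallel grid, while batched input gets a 2-D frame-by-batch tiling. A sketch of that choice follows; the exact tile sizes here are an assumption based on the truncated hunks:

#include <cuda_runtime.h>

// Illustrative launch-shape selection for the LSTM GPU kernels.
void ChooseLstmLaunchShape(int frame_size, int batch_size,
                           dim3* threads, dim3* grid) {
  if (batch_size == 1) {
    int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
    *threads = dim3(frame_per_block, 1);
    *grid = dim3((frame_size + frame_per_block - 1) / frame_per_block, 1);
  } else {
    *threads = dim3(32, 32);  // 32 frame elements x 32 batch rows per block
    *grid = dim3((frame_size + 31) / 32, (batch_size + 31) / 32);
  }
}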
*/ #pragma once #include -#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { namespace detail { namespace forward { @@ -27,9 +26,18 @@ namespace forward { template class lstm { public: - HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og, - T *prev_state, T *state, T *state_atv, T *output, - T *checkI, T *checkF, T *checkO, T *cell_clip, + HOSTDEVICE void operator()(T *value_in, + T *value_ig, + T *value_fg, + T *value_og, + T *prev_state, + T *state, + T *state_atv, + T *output, + T *checkI, + T *checkF, + T *checkO, + T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -57,11 +65,18 @@ class lstm { // Only float support AVX optimization static const bool avx = std::is_same::value; - HOSTDEVICE void operator()(__m256 *value_in, __m256 *value_ig, - __m256 *value_fg, __m256 *value_og, - __m256 *prev_state, __m256 *state, - __m256 *state_atv, __m256 *output, __m256 *checkI, - __m256 *checkF, __m256 *checkO, T *cell_clip, + HOSTDEVICE void operator()(__m256 *value_in, + __m256 *value_ig, + __m256 *value_fg, + __m256 *value_og, + __m256 *prev_state, + __m256 *state, + __m256 *state_atv, + __m256 *output, + __m256 *checkI, + __m256 *checkF, + __m256 *checkO, + T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -97,12 +112,27 @@ namespace backward { template class lstm { public: - HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og, - T *grad_in, T *grad_ig, T *grad_fg, T *grad_og, - T *prev_state, T *prev_state_grad, T *state, - T *state_grad, T *state_atv, T *output_grad, - T *checkI, T *checkF, T *checkO, T *checkIGrad, - T *checkFGrad, T *checkOGrad, T *cell_clip, + HOSTDEVICE void operator()(T *value_in, + T *value_ig, + T *value_fg, + T *value_og, + T *grad_in, + T *grad_ig, + T *grad_fg, + T *grad_og, + T *prev_state, + T *prev_state_grad, + T *state, + T *state_grad, + T *state_atv, + T *output_grad, + T *checkI, + T *checkF, + T *checkO, + T *checkIGrad, + T *checkFGrad, + T *checkOGrad, + T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -138,17 +168,32 @@ class lstm { #else // Only float support AVX optimization static const bool avx = std::is_same::value; - HOSTDEVICE void operator()( - __m256 *value_in, __m256 *value_ig, __m256 *value_fg, __m256 *value_og, - __m256 *grad_in, __m256 *grad_ig, __m256 *grad_fg, __m256 *grad_og, - __m256 *prev_state, __m256 *prev_state_grad, __m256 *state, - __m256 *state_grad, __m256 *state_atv, __m256 *output_grad, - __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad, - __m256 *checkFGrad, __m256 *checkOGrad, T *cell_clip, - ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { - *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og, - active_gate); + HOSTDEVICE void operator()(__m256 *value_in, + __m256 *value_ig, + __m256 *value_fg, + __m256 *value_og, + __m256 *grad_in, + __m256 *grad_ig, + __m256 *grad_fg, + __m256 *grad_og, + __m256 *prev_state, + __m256 *prev_state_grad, + __m256 *state, + __m256 *state_grad, + __m256 *state_atv, + __m256 *output_grad, + __m256 *checkI, + __m256 *checkF, + __m256 *checkO, + __m256 *checkIGrad, + __m256 *checkFGrad, + 
__m256 *checkOGrad, + T *cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + *grad_og = activation( + _mm256_mul_ps(*output_grad, *state_atv), *value_og, active_gate); if (*cell_clip > 0.0f) { T *state_ = reinterpret_cast(state); if (*state_ >= (*cell_clip) || *state_ <= (0.0f - (*cell_clip))) { @@ -156,18 +201,19 @@ class lstm { } else { *state_grad = _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og), - *state_atv, active_state), + *state_atv, + active_state), *state_grad); *state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad); } } - *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in, - active_node); - *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig, - active_gate); - *grad_fg = activation(_mm256_mul_ps(*state_grad, *prev_state), *value_fg, - active_gate); + *grad_in = activation( + _mm256_mul_ps(*state_grad, *value_ig), *value_in, active_node); + *grad_ig = activation( + _mm256_mul_ps(*state_grad, *value_in), *value_ig, active_gate); + *grad_fg = activation( + _mm256_mul_ps(*state_grad, *prev_state), *value_fg, active_gate); *prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI), _mm256_mul_ps(*grad_fg, *checkF)); *prev_state_grad = @@ -183,6 +229,5 @@ class lstm { } // namespace backward } // namespace detail -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/gru_compute.cc b/paddle/phi/kernels/funcs/gru_compute.cc new file mode 100644 index 00000000000..4f159fd28af --- /dev/null +++ b/paddle/phi/kernels/funcs/gru_compute.cc @@ -0,0 +1,373 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/funcs/gru_compute.h" + +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h" +#include "paddle/phi/kernels/funcs/detail/gru_kernel.h" + +namespace phi { +namespace funcs { + +template +struct GRUUnitFunctor { + static void compute(const paddle::platform::CPUDeviceContext &context, + GRUMetaValue value, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate, + bool origin_mode) { +#if !defined(__NVCC__) && !defined(__HIPCC__) + auto blas = + phi::funcs::GetBlas(context); + if (value.prev_out_value) { + blas.GEMM(false, + false, + batch_size, + frame_size * 2, + frame_size, + 1, + value.prev_out_value, + frame_size, + value.gate_weight, + frame_size * 2, + 1, + value.gate_value, + frame_size * 3); + } + + detail::forward_reset_output( + phi::funcs::detail::forward::gru_resetOutput(), + value, + frame_size, + batch_size, + active_gate, + true, + nullptr); + + if (value.prev_out_value) { + blas.GEMM(false, + false, + batch_size, + frame_size, + frame_size, + 1, + value.reset_output_value, + frame_size, + value.state_weight, + frame_size, + 1, + value.gate_value + frame_size * 2, + frame_size * 3); + } + + detail::forward_final_output( + phi::funcs::detail::forward::gru_finalOutput(), + value, + frame_size, + batch_size, + active_node, + origin_mode, + true, + nullptr); +#endif + } +}; + +template +struct GRUUnitGradFunctor { + static void compute(const paddle::platform::CPUDeviceContext &context, + GRUMetaValue value, + GRUMetaGrad grad, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate, + bool origin_mode) { +#if !defined(__NVCC__) && !defined(__HIPCC__) + detail::backward_state_grad( + phi::funcs::detail::backward::gru_stateGrad(), + value, + grad, + frame_size, + batch_size, + active_node, + origin_mode); + auto blas = + phi::funcs::GetBlas(context); + if (value.prev_out_value && grad.prev_out_grad) { + blas.GEMM(false, + true, + batch_size, + frame_size, + frame_size, + 1, + grad.gate_grad + frame_size * 2, + frame_size * 3, + value.state_weight, + frame_size, + 0, + grad.reset_output_grad, + frame_size); + + if (grad.state_weight_grad) { + blas.GEMM(true, + false, + frame_size, + frame_size, + batch_size, + 1, + value.reset_output_value, + frame_size, + grad.gate_grad + frame_size * 2, + frame_size * 3, + 1, + grad.state_weight_grad, + frame_size); + } + } + + detail::backward_reset_grad( + phi::funcs::detail::backward::gru_resetGrad(), + value, + grad, + frame_size, + batch_size, + active_gate); + if (grad.prev_out_grad && value.prev_out_value) { + blas.GEMM(false, + true, + batch_size, + frame_size, + frame_size * 2, + 1, + grad.gate_grad, + frame_size * 3, + value.gate_weight, + frame_size * 2, + 1, + grad.prev_out_grad, + frame_size); + + if (grad.gate_weight_grad) { + blas.GEMM(true, + false, + frame_size, + frame_size * 2, + batch_size, + 1, + value.prev_out_value, + frame_size, + grad.gate_grad, + frame_size * 3, + 1, + grad.gate_weight_grad, + frame_size * 2); + } + } +#endif + } +}; + +template +struct GRUUnitFunctorV2 { + static void compute(const paddle::platform::CPUDeviceContext &context, + GRUMetaValue value, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate) { +#if !defined(__NVCC__) && 
!defined(__HIPCC__) + auto blas = + phi::funcs::GetBlas(context); + if (value.prev_out_value) { + blas.GEMM(CblasNoTrans, + CblasTrans, + batch_size, + frame_size, + frame_size, + 1, + value.prev_out_value, + value.state_weight, + 0, + value.reset_output_value); + } + detail::forward_reset_output( + phi::funcs::detail::forward::gru_resetOutput(), + value, + frame_size, + batch_size, + active_gate, + false, + &context); + + T *cell_state_value = value.gate_value + 2 * frame_size; + T *reset_output_value = value.reset_output_value; + for (int b = 0; b < batch_size; ++b) { + blas.VADD( + frame_size, cell_state_value, reset_output_value, cell_state_value); + cell_state_value += frame_size * 3; + reset_output_value += frame_size; + } + + detail::forward_final_output( + phi::funcs::detail::forward::gru_finalOutput(), + value, + frame_size, + batch_size, + active_node, + true, + false, + &context); +#endif + } +}; + +template +struct GRUUnitGradFunctorV2 { + static void compute(const paddle::platform::CPUDeviceContext &context, + GRUMetaValue value, + GRUMetaGrad grad, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate) { +#if !defined(__NVCC__) && !defined(__HIPCC__) + // calculate grad_update_gate, grad_frame_state, + // grad_reset_output, grad_reset_gate + detail::cpu_gru_backward(context, + phi::funcs::detail::backward::gru(), + value, + grad, + frame_size, + batch_size, + active_node, + active_gate); + auto blas = + phi::funcs::GetBlas(context); + if (grad.prev_out_grad && value.prev_out_value) { + // update prev_out_grad + blas.GEMM(false, + false, + batch_size, + frame_size, + frame_size, + 1, + grad.gate_grad, + frame_size * 3, + value.gate_weight, + frame_size, + 1, + grad.prev_out_grad, + frame_size); + blas.GEMM(false, + false, + batch_size, + frame_size, + frame_size, + 1, + grad.gate_grad + frame_size, + frame_size * 3, + value.gate_weight + frame_size * frame_size, + frame_size, + 1, + grad.prev_out_grad, + frame_size); + blas.GEMM(false, + false, + batch_size, + frame_size, + frame_size, + 1, + grad.reset_output_grad, + frame_size, + value.state_weight, + frame_size, + 1, + grad.prev_out_grad, + frame_size); + // update weight_hh_grad + if (grad.gate_weight_grad) { + // reset gate + blas.GEMM(true, + false, + frame_size, + frame_size, + batch_size, + 1, + grad.gate_grad, + frame_size * 3, + value.prev_out_value, + frame_size, + 1, + grad.gate_weight_grad, + frame_size); + // update gate + blas.GEMM(true, + false, + frame_size, + frame_size, + batch_size, + 1, + grad.gate_grad + frame_size, + frame_size * 3, + value.prev_out_value, + frame_size, + 1, + grad.gate_weight_grad + frame_size * frame_size, + frame_size); + // cell state + blas.GEMM(true, + false, + frame_size, + frame_size, + batch_size, + 1, + grad.reset_output_grad, + frame_size, + value.prev_out_value, + frame_size, + 1, + grad.state_weight_grad, + frame_size); + } + } + // update bias_hh_grad + T *gate_grad = grad.gate_grad; + T *bias_hh_grad = grad.bias_hh_grad; + T *state_bias_grad = grad.bias_hh_grad + 2 * frame_size; + T *reset_output_grad = grad.reset_output_grad; + for (int b = 0; b < batch_size; ++b) { + blas.VADD(2 * frame_size, bias_hh_grad, gate_grad, bias_hh_grad); + blas.VADD( + frame_size, state_bias_grad, reset_output_grad, state_bias_grad); + gate_grad += 3 * frame_size; + reset_output_grad += frame_size; + } +#endif + } +}; + +template struct GRUUnitFunctor; +template struct GRUUnitFunctor; +template 
struct GRUUnitGradFunctor; +template struct GRUUnitGradFunctor; + +template struct GRUUnitFunctorV2; +template struct GRUUnitFunctorV2; +template struct GRUUnitGradFunctorV2; +template struct GRUUnitGradFunctorV2; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/gru_compute.cu b/paddle/phi/kernels/funcs/gru_compute.cu new file mode 100644 index 00000000000..7666206b7f7 --- /dev/null +++ b/paddle/phi/kernels/funcs/gru_compute.cu @@ -0,0 +1,349 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h" +#include "paddle/phi/kernels/funcs/detail/gru_kernel.h" +#include "paddle/phi/kernels/funcs/gru_compute.h" + +namespace phi { +namespace funcs { + +template +struct GRUUnitFunctor { + static void compute(const paddle::platform::CUDADeviceContext &context, + GRUMetaValue value, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate, + bool origin_mode) { + auto stream = context.stream(); + dim3 threads; + dim3 grid; + if (batch_size == 1) { + if (context.GetComputeCapability() >= 70) { + if (frame_size < 16) { + constexpr int tiled_size = 8; + int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; + threads = dim3(tiled_size, 1); + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruGate< + T, + tiled_size><<>>( + value.gate_value, + value.prev_out_value, + value.gate_weight, + value.reset_output_value, + frame_size, + active_gate); + + frame_blocks = (frame_size + tiled_size - 1) / tiled_size; + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruOut< + T, + tiled_size><<>>( + value.state_weight, + value.prev_out_value, + value.output_value, + value.gate_value, + value.reset_output_value, + frame_size, + active_node, + origin_mode); + } else { + constexpr int tiled_size = 16; + int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; + threads = dim3(tiled_size, 1); + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruGate< + T, + tiled_size><<>>( + value.gate_value, + value.prev_out_value, + value.gate_weight, + value.reset_output_value, + frame_size, + active_gate); + + frame_blocks = (frame_size + tiled_size - 1) / tiled_size; + grid = dim3(frame_blocks, 1); + detail::KeFastCollectiveGruOut< + T, + tiled_size><<>>( + value.state_weight, + value.prev_out_value, + value.output_value, + value.gate_value, + value.reset_output_value, + frame_size, + active_node, + origin_mode); + } + return; + } else { + int frame_per_block = frame_size <= 1024 ? 
frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); + } + } else { + threads = dim3(32, 32); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); + } + auto blas = + phi::funcs::GetBlas(context); + if (value.prev_out_value) { + blas.GEMM(false, + false, + batch_size, + frame_size * 2, + frame_size, + 1, + value.prev_out_value, + frame_size, + value.gate_weight, + frame_size * 2, + 1, + value.gate_value, + frame_size * 3); + } + + if (batch_size == 1) { + detail::KeGruForwardResetOutput< + phi::funcs::detail::forward::gru_resetOutput, + /* is_batch= */ false, + T><<>>( + phi::funcs::detail::forward::gru_resetOutput(), + value.gate_value, + value.reset_output_value, + value.prev_out_value, + frame_size, + batch_size, + active_gate); + } else { + detail::KeGruForwardResetOutput< + phi::funcs::detail::forward::gru_resetOutput, + /* is_batch= */ true, + T><<>>( + phi::funcs::detail::forward::gru_resetOutput(), + value.gate_value, + value.reset_output_value, + value.prev_out_value, + frame_size, + batch_size, + active_gate); + } + + if (value.prev_out_value) { + blas.GEMM(false, + false, + batch_size, + frame_size, + frame_size, + 1, + value.reset_output_value, + frame_size, + value.state_weight, + frame_size, + 1, + value.gate_value + frame_size * 2, + frame_size * 3); + } + + if (batch_size == 1) { + detail::KeGruForwardFinalOutput< + phi::funcs::detail::forward::gru_finalOutput, + /* is_batch= */ false, + T><<>>( + phi::funcs::detail::forward::gru_finalOutput(), + value.gate_value, + value.prev_out_value, + value.output_value, + frame_size, + batch_size, + active_node, + origin_mode); + } else { + detail::KeGruForwardFinalOutput< + phi::funcs::detail::forward::gru_finalOutput, + /* is_batch= */ true, + T><<>>( + phi::funcs::detail::forward::gru_finalOutput(), + value.gate_value, + value.prev_out_value, + value.output_value, + frame_size, + batch_size, + active_node, + origin_mode); + } + } +}; + +template +struct GRUUnitGradFunctor { + static void compute(const paddle::platform::CUDADeviceContext &context, + GRUMetaValue value, + GRUMetaGrad grad, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate, + bool origin_mode) { + auto stream = context.stream(); + dim3 threads; + dim3 grid; + if (batch_size == 1) { + int frame_per_block = frame_size <= 1024 ? 
frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); + } else { + threads = dim3(32, 32); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); + } + + if (batch_size == 1) { + detail::KeGruBackwardStateGrad< + phi::funcs::detail::backward::gru_stateGrad, + /* is_batch= */ false><<>>( + phi::funcs::detail::backward::gru_stateGrad(), + value.gate_value, + grad.gate_grad, + value.prev_out_value, + grad.prev_out_grad, + grad.output_grad, + frame_size, + batch_size, + active_node, + origin_mode); + } else { + detail::KeGruBackwardStateGrad< + phi::funcs::detail::backward::gru_stateGrad, + /* is_batch= */ true><<>>( + phi::funcs::detail::backward::gru_stateGrad(), + value.gate_value, + grad.gate_grad, + value.prev_out_value, + grad.prev_out_grad, + grad.output_grad, + frame_size, + batch_size, + active_node, + origin_mode); + } + + auto blas = + phi::funcs::GetBlas(context); + + if (value.prev_out_value && grad.prev_out_grad) { + blas.GEMM(false, + true, + batch_size, + frame_size, + frame_size, + 1, + grad.gate_grad + frame_size * 2, + frame_size * 3, + value.state_weight, + frame_size, + 0, + grad.reset_output_grad, + frame_size); + + if (grad.state_weight_grad) { + blas.GEMM(true, + false, + frame_size, + frame_size, + batch_size, + 1, + value.reset_output_value, + frame_size, + grad.gate_grad + frame_size * 2, + frame_size * 3, + 1, + grad.state_weight_grad, + frame_size); + } + } + + if (batch_size == 1) { + detail::KeGruBackwardResetGrad< + phi::funcs::detail::backward::gru_resetGrad, + /* is_batch= */ false><<>>( + phi::funcs::detail::backward::gru_resetGrad(), + value.gate_value, + grad.gate_grad, + value.prev_out_value, + grad.prev_out_grad, + grad.reset_output_grad, + frame_size, + batch_size, + active_gate); + } else { + detail::KeGruBackwardResetGrad< + phi::funcs::detail::backward::gru_resetGrad, + /* is_batch= */ true><<>>( + phi::funcs::detail::backward::gru_resetGrad(), + value.gate_value, + grad.gate_grad, + value.prev_out_value, + grad.prev_out_grad, + grad.reset_output_grad, + frame_size, + batch_size, + active_gate); + } + + if (grad.prev_out_grad && value.prev_out_value) { + blas.GEMM(false, + true, + batch_size, + frame_size, + frame_size * 2, + 1, + grad.gate_grad, + frame_size * 3, + value.gate_weight, + frame_size * 2, + 1, + grad.prev_out_grad, + frame_size); + + if (grad.gate_weight_grad) { + blas.GEMM(true, + false, + frame_size, + frame_size * 2, + batch_size, + 1, + value.prev_out_value, + frame_size, + grad.gate_grad, + frame_size * 3, + 1, + grad.gate_weight_grad, + frame_size * 2); + } + } + } +}; + +template struct GRUUnitFunctor; +template struct GRUUnitFunctor; +template struct GRUUnitGradFunctor; +template struct GRUUnitGradFunctor; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/gru_compute.h b/paddle/phi/kernels/funcs/gru_compute.h new file mode 100644 index 00000000000..02b2b91423c --- /dev/null +++ b/paddle/phi/kernels/funcs/gru_compute.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" + +namespace phi { +namespace funcs { + +template +struct GRUMetaValue { + const T *gate_weight; + const T *state_weight; + const T *reset_bias; + T *gate_value; + T *reset_output_value; + T *output_value; + const T *prev_out_value; +}; + +template +struct GRUMetaGrad { + T *gate_weight_grad; + T *state_weight_grad; + T *gate_grad; + T *reset_output_grad; + T *output_grad; + T *prev_out_grad; + T *bias_hh_grad; +}; + +template +struct GRUUnitFunctor { + static void compute(const DeviceContext &context, + GRUMetaValue value, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate, + bool origin_mode); +}; + +template +struct GRUUnitGradFunctor { + static void compute(const DeviceContext &context, + GRUMetaValue value, + GRUMetaGrad grad, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate, + bool origin_mode); +}; + +template +struct GRUUnitFunctorV2 { + static void compute(const DeviceContext &context, + GRUMetaValue value, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate); +}; + +template +struct GRUUnitGradFunctorV2 { + static void compute(const DeviceContext &context, + GRUMetaValue value, + GRUMetaGrad grad, + int frame_size, + int batch_size, + const phi::funcs::detail::ActivationType active_node, + const phi::funcs::detail::ActivationType active_gate); +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/lstm_compute.cc b/paddle/phi/kernels/funcs/lstm_compute.cc new file mode 100644 index 00000000000..19932c62b01 --- /dev/null +++ b/paddle/phi/kernels/funcs/lstm_compute.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/funcs/lstm_compute.h" +#include "paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h" +#include "paddle/phi/kernels/funcs/detail/lstm_kernel.h" + +namespace phi { +namespace funcs { + +template +struct LstmUnitFunctor { + static void compute(const paddle::platform::CPUDeviceContext& context, + LstmMetaValue value, + int frame_size, + int batch_size, + T cell_clip, + const phi::funcs::detail::ActivationType& gate_act, + const phi::funcs::detail::ActivationType& cell_act, + const phi::funcs::detail::ActivationType& cand_act, + bool old_api_version = true) { + for (int b = 0; b < batch_size; b++) { + detail::cpu_lstm_forward(context, + phi::funcs::detail::forward::lstm(), + value, + frame_size, + cell_clip, + cand_act, + gate_act, + cell_act, + old_api_version); + value.gate_value += frame_size * 4; + value.state_value += frame_size; + value.state_active_value += frame_size; + value.output_value += frame_size; + if (value.prev_state_value) { + value.prev_state_value += frame_size; + } + } + } +}; + +template +struct LstmUnitGradFunctor { + static void compute(const paddle::platform::CPUDeviceContext& context, + LstmMetaValue value, + LstmMetaGrad grad, + int frame_size, + int batch_size, + T cell_clip, + const phi::funcs::detail::ActivationType& gate_act, + const phi::funcs::detail::ActivationType& cell_act, + const phi::funcs::detail::ActivationType& cand_act, + bool old_api_version = true) { + for (int b = 0; b < batch_size; b++) { + detail::cpu_lstm_backward(context, + phi::funcs::detail::backward::lstm(), + value, + grad, + frame_size, + cell_clip, + cand_act, + gate_act, + cell_act, + old_api_version); + + value.gate_value += frame_size * 4; + value.state_value += frame_size; + value.state_active_value += frame_size; + value.output_value += frame_size; + if (value.prev_state_value) { + value.prev_state_value += frame_size; + } + + grad.gate_grad += frame_size * 4; + grad.state_grad += frame_size; + grad.state_active_grad += frame_size; + grad.output_grad += frame_size; + if (grad.prev_state_grad) { + grad.prev_state_grad += frame_size; + } + } + } +}; + +template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; +template class LstmUnitGradFunctor; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/lstm_compute.cu b/paddle/phi/kernels/funcs/lstm_compute.cu new file mode 100644 index 00000000000..b2057cfc4f9 --- /dev/null +++ b/paddle/phi/kernels/funcs/lstm_compute.cu @@ -0,0 +1,76 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/funcs/detail/lstm_gpu_kernel.h" +#include "paddle/phi/kernels/funcs/detail/lstm_kernel.h" +#include "paddle/phi/kernels/funcs/lstm_compute.h" + +namespace phi { +namespace funcs { + +template +struct LstmUnitFunctor { + static void compute(const paddle::platform::CUDADeviceContext& context, + LstmMetaValue value, + int frame_size, + int batch_size, + T cell_clip, + const phi::funcs::detail::ActivationType& gate_act, + const phi::funcs::detail::ActivationType& cell_act, + const phi::funcs::detail::ActivationType& cand_act, + bool old_api_version = true) { + detail::gpu_lstm_forward(context, + phi::funcs::detail::forward::lstm(), + value, + frame_size, + batch_size, + cell_clip, + cand_act, + gate_act, + cell_act); + } +}; + +template +struct LstmUnitGradFunctor { + static void compute(const paddle::platform::CUDADeviceContext& context, + LstmMetaValue value, + LstmMetaGrad grad, + int frame_size, + int batch_size, + T cell_clip, + const phi::funcs::detail::ActivationType& gate_act, + const phi::funcs::detail::ActivationType& cell_act, + const phi::funcs::detail::ActivationType& cand_act, + bool old_api_version = true) { + detail::gpu_lstm_backward(context, + phi::funcs::detail::backward::lstm(), + value, + grad, + frame_size, + batch_size, + cell_clip, + cand_act, + gate_act, + cell_act); + } +}; + +template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; +template class LstmUnitGradFunctor; + +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/lstm_compute.h b/paddle/phi/kernels/funcs/lstm_compute.h similarity index 56% rename from paddle/fluid/operators/math/lstm_compute.h rename to paddle/phi/kernels/funcs/lstm_compute.h index cc91f784f39..d51b92fc4fd 100644 --- a/paddle/fluid/operators/math/lstm_compute.h +++ b/paddle/phi/kernels/funcs/lstm_compute.h @@ -14,13 +14,12 @@ limitations under the License. 
*/ #pragma once -#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template struct LstmMetaValue { @@ -49,25 +48,31 @@ struct LstmMetaGrad { template class LstmUnitFunctor { public: - static void compute(const DeviceContext &context, LstmMetaValue value, - int frame_size, int batch_size, T cell_clip, - const detail::ActivationType &gate_act, - const detail::ActivationType &cell_act, - const detail::ActivationType &cand_act, + static void compute(const DeviceContext &context, + LstmMetaValue value, + int frame_size, + int batch_size, + T cell_clip, + const phi::funcs::detail::ActivationType &gate_act, + const phi::funcs::detail::ActivationType &cell_act, + const phi::funcs::detail::ActivationType &cand_act, bool old_api_version = true); }; template class LstmUnitGradFunctor { public: - static void compute(const DeviceContext &context, LstmMetaValue value, - LstmMetaGrad grad, int frame_size, int batch_size, - T cell_clip, const detail::ActivationType &gate_act, - const detail::ActivationType &cell_act, - const detail::ActivationType &cand_act, + static void compute(const DeviceContext &context, + LstmMetaValue value, + LstmMetaGrad grad, + int frame_size, + int batch_size, + T cell_clip, + const phi::funcs::detail::ActivationType &gate_act, + const phi::funcs::detail::ActivationType &cell_act, + const phi::funcs::detail::ActivationType &cand_act, bool old_api_version = true); }; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/sequence2batch.cc b/paddle/phi/kernels/funcs/sequence2batch.cc similarity index 56% rename from paddle/fluid/operators/math/sequence2batch.cc rename to paddle/phi/kernels/funcs/sequence2batch.cc index 852700fa7ff..0d75ba877db 100644 --- a/paddle/fluid/operators/math/sequence2batch.cc +++ b/paddle/phi/kernels/funcs/sequence2batch.cc @@ -12,47 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" -namespace paddle { -namespace platform { -class CPUDeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template -class CopyMatrixRowsFunctor { +class CopyMatrixRowsFunctor { public: - void operator()(const platform::CPUDeviceContext& context, - const framework::Tensor& src, - framework::Vector index_lod, framework::Tensor* dst, + void operator()(const paddle::platform::CPUDeviceContext& context, + const paddle::framework::Tensor& src, + paddle::framework::Vector index_lod, + paddle::framework::Tensor* dst, bool is_src_index) { size_t* index = index_lod.data(); auto src_dims = src.dims(); auto dst_dims = dst->dims(); - PADDLE_ENFORCE_EQ(src_dims.size(), 2UL, - platform::errors::InvalidArgument( + PADDLE_ENFORCE_EQ(src_dims.size(), + 2UL, + phi::errors::InvalidArgument( "The source tensor must be a matrix with rank 2, but " "got the source tensor rank is %lu. 
" "Please check the rank of the source tensor", src_dims.size())); - PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL, - platform::errors::InvalidArgument( + PADDLE_ENFORCE_EQ(dst_dims.size(), + 2UL, + phi::errors::InvalidArgument( "The destination tensor must be a matrix with rank, " "but got the destination tensor rank is %lu. " "Please check the rank of the destination tensor", dst_dims.size())); PADDLE_ENFORCE_EQ( - src_dims[1], dst_dims[1], - platform::errors::InvalidArgument( + src_dims[1], + dst_dims[1], + phi::errors::InvalidArgument( "The width of the source tensor and the destination tensor must be " "same. But got %lu != %lu.Please check the rank of the source " "tensor", - src_dims.size(), dst_dims.size())); + src_dims.size(), + dst_dims.size())); auto height = dst_dims[0]; auto width = dst_dims[1]; auto* src_data = src.data(); @@ -70,14 +68,18 @@ class CopyMatrixRowsFunctor { } }; -template class CopyMatrixRowsFunctor; -template class CopyMatrixRowsFunctor; +template class CopyMatrixRowsFunctor; +template class CopyMatrixRowsFunctor; -template class LoDTensor2BatchFunctor; -template class LoDTensor2BatchFunctor; -template class Batch2LoDTensorFunctor; -template class Batch2LoDTensorFunctor; +template class LoDTensor2BatchFunctor; +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensorFunctor; +template class Batch2LoDTensorFunctor; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/sequence2batch.cu b/paddle/phi/kernels/funcs/sequence2batch.cu similarity index 55% rename from paddle/fluid/operators/math/sequence2batch.cu rename to paddle/phi/kernels/funcs/sequence2batch.cu index f56c5293971..a66030e6426 100644 --- a/paddle/fluid/operators/math/sequence2batch.cu +++ b/paddle/phi/kernels/funcs/sequence2batch.cu @@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template -__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index, - int64_t height, int64_t width, +__global__ void CopyMatrixRowsKernel(const T* src, + T* dst, + const size_t* index, + int64_t height, + int64_t width, bool is_src_index) { int idx = threadIdx.x; int idy = threadIdx.y; @@ -37,33 +39,38 @@ __global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index, } template -class CopyMatrixRowsFunctor { +class CopyMatrixRowsFunctor { public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& src, - framework::Vector index_lod, framework::Tensor* dst, + void operator()(const paddle::platform::CUDADeviceContext& context, + const paddle::framework::Tensor& src, + paddle::framework::Vector index_lod, + paddle::framework::Tensor* dst, bool is_src_index) { auto src_dims = src.dims(); auto dst_dims = dst->dims(); - PADDLE_ENFORCE_EQ(src_dims.size(), 2, - platform::errors::InvalidArgument( + PADDLE_ENFORCE_EQ(src_dims.size(), + 2, + phi::errors::InvalidArgument( "The source tensor must be a matrix with rank 2, but " "got the source tensor rank is %lu. 
" "Please check the rank of the source tensor", src_dims.size())); - PADDLE_ENFORCE_EQ(dst_dims.size(), 2, - platform::errors::InvalidArgument( + PADDLE_ENFORCE_EQ(dst_dims.size(), + 2, + phi::errors::InvalidArgument( "The destination tensor must be a matrix with rank, " "but got the destination tensor rank is %lu. " "Please check the rank of the destination tensor", dst_dims.size())); PADDLE_ENFORCE_EQ( - src_dims[1], dst_dims[1], - platform::errors::InvalidArgument( + src_dims[1], + dst_dims[1], + phi::errors::InvalidArgument( "The width of the source tensor and the destination tensor must be " "same. But got %lu != %lu.Please check the rank of the source " "tensor", - src_dims.size(), dst_dims.size())); + src_dims.size(), + dst_dims.size())); auto height = dst_dims[0]; auto width = dst_dims[1]; auto* src_data = src.data(); @@ -74,19 +81,28 @@ class CopyMatrixRowsFunctor { auto stream = context.stream(); paddle::framework::MixVector mix_index_lod(&index_lod); CopyMatrixRowsKernel<<>>( - src_data, dst_data, mix_index_lod.CUDAData(context.GetPlace()), height, - width, is_src_index); + src_data, + dst_data, + mix_index_lod.CUDAData(context.GetPlace()), + height, + width, + is_src_index); } }; -template class CopyMatrixRowsFunctor; -template class CopyMatrixRowsFunctor; +template class CopyMatrixRowsFunctor; +template class CopyMatrixRowsFunctor; -template class LoDTensor2BatchFunctor; -template class LoDTensor2BatchFunctor; -template class Batch2LoDTensorFunctor; -template class Batch2LoDTensorFunctor; +template class LoDTensor2BatchFunctor; +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensorFunctor; +template class Batch2LoDTensorFunctor; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/phi/kernels/funcs/sequence2batch.h similarity index 80% rename from paddle/fluid/operators/math/sequence2batch.h rename to paddle/phi/kernels/funcs/sequence2batch.h index 6aa513e4d10..e7c387fb99b 100644 --- a/paddle/fluid/operators/math/sequence2batch.h +++ b/paddle/phi/kernels/funcs/sequence2batch.h @@ -20,13 +20,13 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { -template -using EigenMatrix = framework::EigenMatrix; +using EigenMatrix = paddle::framework::EigenMatrix; template class CopyMatrixRowsFunctor { @@ -36,8 +36,10 @@ class CopyMatrixRowsFunctor { // If is_src_index is false, // copy the input src to the indexed rows of output dst. // The indexed rows are based on the input index. 
- void operator()(const DeviceContext& context, const framework::Tensor& src, - framework::Vector index_lod, framework::Tensor* dst, + void operator()(const DeviceContext& context, + const paddle::framework::Tensor& src, + paddle::framework::Vector index_lod, + paddle::framework::Tensor* dst, bool is_src_index); }; @@ -59,32 +61,37 @@ class LoDTensor2BatchFunctor { public: void operator()(const DeviceContext& context, - const framework::LoDTensor& lod_tensor, - framework::LoDTensor* batch, bool is_cal_batch_lod, + const paddle::framework::LoDTensor& lod_tensor, + paddle::framework::LoDTensor* batch, + bool is_cal_batch_lod, bool is_reverse = false) const { if (!is_cal_batch_lod) { auto lods = batch->lod(); PADDLE_ENFORCE_GT( - lods.size(), 2UL, - platform::errors::InvalidArgument( + lods.size(), + 2UL, + phi::errors::InvalidArgument( "The LoD of LoDTensor should include at least 2-level " "sequence information, but got the LoD level is %lu. Please " "check the input value.", lods.size())); PADDLE_ENFORCE_EQ( - lods[1].size(), static_cast(lod_tensor.dims()[0]), - platform::errors::InvalidArgument( + lods[1].size(), + static_cast(lod_tensor.dims()[0]), + phi::errors::InvalidArgument( "The LoD information should be consistent with the dims, but got " "%lu != %lu. Please check the input value.", - lods[1].size(), static_cast(lod_tensor.dims()[0]))); + lods[1].size(), + static_cast(lod_tensor.dims()[0]))); CopyMatrixRowsFunctor to_batch; to_batch(context, lod_tensor, lods[1], batch, true); return; } auto lods = lod_tensor.lod(); - PADDLE_ENFORCE_EQ(lods.size(), 1UL, - platform::errors::InvalidArgument( + PADDLE_ENFORCE_EQ(lods.size(), + 1UL, + phi::errors::InvalidArgument( "Only one-level sequence is supported now, but got the " "LoD level is %lu. Please check the input value.", lods.size())); @@ -97,8 +104,9 @@ class LoDTensor2BatchFunctor { seq_info.emplace_back(lod[seq_id], length, seq_id); } - std::sort(seq_info.begin(), seq_info.end(), - [](SeqInfo a, SeqInfo b) { return a.length > b.length; }); + std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) { + return a.length > b.length; + }); // Calculate the start position of each batch. // example: sequences = {s0, s1, s2} @@ -169,27 +177,29 @@ template class Batch2LoDTensorFunctor { public: void operator()(const DeviceContext& context, - const framework::LoDTensor& batch, - framework::LoDTensor* lod_tensor) const { + const paddle::framework::LoDTensor& batch, + paddle::framework::LoDTensor* lod_tensor) const { auto in_lod = batch.lod(); PADDLE_ENFORCE_GT( - in_lod.size(), 2UL, - platform::errors::InvalidArgument( + in_lod.size(), + 2UL, + phi::errors::InvalidArgument( "The LoD of LoDTensor should include at least 2-level " "sequence information, but got the LoD level is %lu. Please check " "the input value.", in_lod.size())); PADDLE_ENFORCE_EQ( - in_lod[1].size(), static_cast(lod_tensor->dims()[0]), - platform::errors::InvalidArgument( + in_lod[1].size(), + static_cast(lod_tensor->dims()[0]), + phi::errors::InvalidArgument( "The LoD information should be consistent with the dims, but got " "%lu != %lu. Please check the input value.", - in_lod[1].size(), static_cast(lod_tensor->dims()[0]))); + in_lod[1].size(), + static_cast(lod_tensor->dims()[0]))); CopyMatrixRowsFunctor to_seq; to_seq(context, batch, in_lod[1], lod_tensor, false); } }; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi -- GitLab
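Usage sketch (illustrative only, not part of the patch): after this migration, call sites reach the GRU/LSTM math through phi::funcs rather than paddle::operators::math. Below is a minimal sketch of driving the relocated CPU GRU functor, assuming float data; the buffer names and sizes are hypothetical, and the ActivationType enumerators kSigmoid/kTanh are assumed to come from paddle/phi/kernels/funcs/detail/activation_functions.h.

// A minimal sketch: one forward GRU step through the migrated phi::funcs
// API. Float data and a CPU context are assumed; buffer names/sizes are
// hypothetical, and kSigmoid/kTanh are assumed enumerators of
// phi::funcs::detail::ActivationType.
#include <vector>

#include "paddle/phi/kernels/funcs/gru_compute.h"

void GruStepSketch(const paddle::platform::CPUDeviceContext& ctx,
                   int frame_size,
                   int batch_size) {
  using T = float;
  // gate_value packs the update, reset, and candidate gates: 3 * frame_size
  // per batch row; gate_weight is [frame_size, 2 * frame_size] and
  // state_weight is [frame_size, frame_size], matching the GEMM shapes used
  // in gru_compute.cc above.
  std::vector<T> gate_value(batch_size * frame_size * 3, T(0));
  std::vector<T> reset_output(batch_size * frame_size, T(0));
  std::vector<T> output(batch_size * frame_size, T(0));
  std::vector<T> prev_out(batch_size * frame_size, T(0));
  std::vector<T> gate_weight(frame_size * frame_size * 2, T(0));
  std::vector<T> state_weight(frame_size * frame_size, T(0));

  phi::funcs::GRUMetaValue<T> value;
  value.gate_weight = gate_weight.data();
  value.state_weight = state_weight.data();
  value.reset_bias = nullptr;  // only consumed by the V2 code path
  value.gate_value = gate_value.data();
  value.reset_output_value = reset_output.data();
  value.output_value = output.data();
  value.prev_out_value = prev_out.data();  // pass nullptr at the first step

  // active_node drives the candidate state, active_gate the update/reset
  // gates; origin_mode selects the original GRU formulation.
  phi::funcs::GRUUnitFunctor<paddle::platform::CPUDeviceContext, T>::compute(
      ctx,
      value,
      frame_size,
      batch_size,
      phi::funcs::detail::ActivationType::kTanh,
      phi::funcs::detail::ActivationType::kSigmoid,
      /*origin_mode=*/false);
}

The V2 functors follow the same calling pattern but take no origin_mode flag and additionally use the reset_bias field of GRUMetaValue, per the declarations in gru_compute.h.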