Unverified commit e4dba69a, authored by Feiyu Chan, committed by GitHub

[Pten] Gru lstm migration (#39729)

* move sequence2batch

* move lstm and gru

* Add the phi/kernels directory to the exclusion list so that hipcc is not used to compile the non-.cu files in it.
Parent: dbcf8797
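The change is almost entirely mechanical: the sequence2batch, gru_compute, and lstm_compute functors (and their detail kernels) move from paddle/fluid/operators/math to paddle/phi/kernels/funcs, so every include switches to the phi path and every call site is requalified from math:: to phi::funcs::. A minimal, self-contained C++ sketch of that requalification pattern (toy names, not Paddle code):

// Illustrative sketch only -- toy names, not Paddle code. A functor moves
// from the old operators::math namespace into phi::funcs; call sites are
// updated to the new qualified name. A namespace alias (hypothetical here;
// the real commit simply edits the call sites, as the hunks below show) can
// keep old spellings compiling during such a migration.
#include <iostream>

namespace phi {
namespace funcs {
// Stand-in for a migrated functor such as LoDTensor2BatchFunctor.
template <typename T>
struct ToBatchFunctor {
  void operator()(const T& x) const { std::cout << "to_batch(" << x << ")\n"; }
};
}  // namespace funcs
}  // namespace phi

namespace paddle {
namespace operators {
namespace math = ::phi::funcs;  // hypothetical compatibility alias
}  // namespace operators
}  // namespace paddle

int main() {
  phi::funcs::ToBatchFunctor<int> to_batch;                // new spelling
  paddle::operators::math::ToBatchFunctor<int> to_batch2;  // old spelling via alias
  to_batch(1);
  to_batch2(2);
  return 0;
}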
@@ -580,8 +580,8 @@ function(hip_library TARGET_NAME)
   cmake_parse_arguments(hip_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
   if(hip_library_SRCS)
     # FindHIP.cmake defined hip_add_library, HIP_SOURCE_PROPERTY_FORMAT is requried if no .cu files found
-    if(NOT ${CMAKE_CURRENT_SOURCE_DIR} MATCHES ".*/operators")
+    if(NOT (${CMAKE_CURRENT_SOURCE_DIR} MATCHES ".*/operators" OR ${CMAKE_CURRENT_SOURCE_DIR} MATCHES ".*/phi/kernels"))
       set_source_files_properties(${hip_library_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
     endif()
     if (hip_library_SHARED OR hip_library_shared) # build *.so
       hip_add_library(${TARGET_NAME} SHARED ${hip_library_SRCS})
...
@@ -15,9 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h"
 #include <string>
 #include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/sequence2batch.h"

 namespace paddle {
 namespace operators {
@@ -473,7 +473,7 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
     hidden_out->mutable_data<T>(place);
     cell_out->mutable_data<T>(place);

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
@@ -591,7 +591,7 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
 #undef MOVE_ONE_BATCH
 #undef DEFINE_CUR

-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batched_h_out->set_lod(batched_lod);
     to_seq(dev_ctx, *batched_h_out, hidden_out);
     batched_c_out->set_lod(batched_lod);
...
@@ -19,8 +19,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/fc.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/sequence2batch.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -368,7 +368,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     hidden_out->mutable_data<T>(place);
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     math::FCFunctor<DeviceContext, T> fc;

     if (M > D3) {
@@ -463,7 +463,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       batched_input_data = cur_batched_data;
     }

-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batched_out->set_lod(batched_lod);
     to_seq(dev_ctx, *batched_out, hidden_out);
   }
...
@@ -16,8 +16,8 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/fc.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/sequence2batch.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -421,7 +421,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     hidden_out->mutable_data<T>(place);
     cell_out->mutable_data<T>(place);

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     math::FCFunctor<DeviceContext, T> fc;
@@ -514,7 +514,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       batched_input_data = cur_in_data;
     }

-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batched_h_out->set_lod(batched_lod);
     to_seq(dev_ctx, *batched_h_out, hidden_out);
     batched_c_out->set_lod(batched_lod);
...
@@ -19,8 +19,8 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/fc.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/sequence2batch.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
...
@@ -15,9 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/gru_op.h"

 #include <memory>
 #include <string>
-#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
-#include "paddle/fluid/operators/math/detail/gru_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h"
+#include "paddle/phi/kernels/funcs/detail/gru_kernel.h"

 DECLARE_int32(paddle_num_threads);
@@ -316,7 +316,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
     batch_hidden->mutable_data<T>(context.GetPlace());
     bool is_reverse = context.Attr<bool>("is_reverse");

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = context.template device_context<DeviceContext>();
     to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
@@ -326,7 +326,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
     }
     int frame_size = hidden_dims[1];

-    math::GRUMetaValue<T> gru_value;
+    phi::funcs::GRUMetaValue<T> gru_value;
     gru_value.gate_weight = const_cast<T*>(weight_data);
     gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
@@ -347,9 +347,9 @@ class GRUCPUKernel : public framework::OpKernel<T> {
     }
     auto batch_starts = batch_gate->lod()[0];
     size_t seq_len = batch_starts.size() - 1;
-    auto active_node = math::detail::GetActivationType(
+    auto active_node = phi::funcs::detail::GetActivationType(
         context.Attr<std::string>("activation"));
-    auto active_gate = math::detail::GetActivationType(
+    auto active_gate = phi::funcs::detail::GetActivationType(
        context.Attr<std::string>("gate_activation"));

 #ifdef PADDLE_WITH_MKLML
@@ -396,9 +396,9 @@ class GRUCPUKernel : public framework::OpKernel<T> {
                 frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
           }

-          math::detail::forward_reset_output(
-              math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
-              cur_batch_size, active_gate);
+          phi::funcs::detail::forward_reset_output(
+              phi::funcs::detail::forward::gru_resetOutput<T>(), gru_value,
+              frame_size, cur_batch_size, active_gate);

           if (gru_value.prev_out_value) {
             blas.GEMM_COMPUTE(
@@ -408,9 +408,9 @@ class GRUCPUKernel : public framework::OpKernel<T> {
                 frame_size * 3);
           }

-          math::detail::forward_final_output(
-              math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
-              cur_batch_size, active_node, origin_mode);
+          phi::funcs::detail::forward_final_output(
+              phi::funcs::detail::forward::gru_finalOutput<T>(), gru_value,
+              frame_size, cur_batch_size, active_node, origin_mode);

           gru_value.prev_out_value = gru_value.output_value;
         }
@@ -432,7 +432,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
         gru_value.gate_value = gate_t.data<T>();
         gru_value.reset_output_value = reset_hidden_prev_t.data<T>();

-        math::GRUUnitFunctor<DeviceContext, T>::compute(
+        phi::funcs::GRUUnitFunctor<DeviceContext, T>::compute(
             dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
             active_gate, origin_mode);
@@ -441,7 +441,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
 #ifdef PADDLE_WITH_MKLML
     }
 #endif
-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_hidden->set_lod(batch_gate->lod());
     to_seq(dev_ctx, *batch_hidden, hidden);
   }
...
@@ -65,7 +65,7 @@ class GRUKernel : public framework::OpKernel<T> {
     batch_hidden->mutable_data<T>(context.GetPlace());
     bool is_reverse = context.Attr<bool>("is_reverse");

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = context.template device_context<DeviceContext>();
     to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
@@ -75,7 +75,7 @@ class GRUKernel : public framework::OpKernel<T> {
     }
     int frame_size = hidden_dims[1];

-    math::GRUMetaValue<T> gru_value;
+    phi::funcs::GRUMetaValue<T> gru_value;
     gru_value.gate_weight = const_cast<T*>(weight_data);
     gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
@@ -96,9 +96,9 @@ class GRUKernel : public framework::OpKernel<T> {
     }
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
-    auto active_node = math::detail::GetActivationType(
+    auto active_node = phi::funcs::detail::GetActivationType(
         context.Attr<std::string>("activation"));
-    auto active_gate = math::detail::GetActivationType(
+    auto active_gate = phi::funcs::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
@@ -111,13 +111,13 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.output_value = hidden_t.data<T>();
       gru_value.gate_value = gate_t.data<T>();
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
-      math::GRUUnitFunctor<DeviceContext, T>::compute(
+      phi::funcs::GRUUnitFunctor<DeviceContext, T>::compute(
           dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
           active_gate, origin_mode);
       gru_value.prev_out_value = gru_value.output_value;
     }

-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_hidden->set_lod(batch_gate->lod());
     to_seq(dev_ctx, *batch_hidden, hidden);
   }
...
@@ -16,10 +16,10 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
+#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/gru_compute.h"
+#include "paddle/phi/kernels/funcs/gru_compute.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/sequence2batch.h"

 namespace paddle {
 namespace operators {
@@ -32,7 +32,7 @@ inline void ReorderInitState(const DeviceContext& ctx,
                              const framework::Tensor& src,
                              framework::Vector<size_t> index_lod,
                              framework::Tensor* dst, bool indexed_src) {
-  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
+  phi::funcs::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
   dst->mutable_data<T>(src.dims(), ctx.GetPlace());
   row_shuffle(ctx, src, index_lod, dst, indexed_src);
 }
@@ -63,7 +63,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
     auto hidden_dims = hidden->dims();
     int frame_size = hidden_dims[1];

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad;
     batch_hidden_grad.mutable_data<T>(hidden_dims, context.GetPlace());
     batch_gate_grad.mutable_data<T>(gate_dims, context.GetPlace());
@@ -93,12 +93,12 @@ class GRUGradKernel : public framework::OpKernel<T> {
     batch_hidden_grad.set_lod(batch_hidden->lod());
     to_batch(dev_ctx, *hidden_grad, &batch_hidden_grad, false, is_reverse);

-    math::GRUMetaValue<T> gru_value;
+    phi::funcs::GRUMetaValue<T> gru_value;
     gru_value.gate_weight = const_cast<T*>(weight_data);
     gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);

-    math::GRUMetaGrad<T> gru_grad;
+    phi::funcs::GRUMetaGrad<T> gru_grad;
     if (weight_grad) {
       gru_grad.gate_weight_grad =
           weight_grad->mutable_data<T>(context.GetPlace());
@@ -112,9 +112,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
     auto batch_starts = batch_hidden_grad.lod()[0];
     size_t num_batch = batch_starts.size() - 1;
-    auto active_node = math::detail::GetActivationType(
+    auto active_node = phi::funcs::detail::GetActivationType(
         context.Attr<std::string>("activation"));
-    auto active_gate = math::detail::GetActivationType(
+    auto active_gate = phi::funcs::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
     for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
       int bstart = static_cast<int>(batch_starts[n]);
@@ -145,13 +145,13 @@ class GRUGradKernel : public framework::OpKernel<T> {
         gru_grad.prev_out_grad = hidden_prev_grad_t.data<T>();
       }
       gru_value.output_value = nullptr;
-      math::GRUUnitGradFunctor<DeviceContext, T>::compute(
+      phi::funcs::GRUUnitGradFunctor<DeviceContext, T>::compute(
           dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
           active_gate, origin_mode);
     }
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
-      math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+      phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
       batch_gate_grad.set_lod(batch_gate->lod());
       to_seq(dev_ctx, batch_gate_grad, input_grad);
     }
...
@@ -15,10 +15,10 @@ limitations under the License. */
 #pragma once
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/lstm_compute.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
+#include "paddle/phi/kernels/funcs/lstm_compute.h"
+#include "paddle/phi/kernels/funcs/sequence2batch.h"

 namespace paddle {
 namespace operators {
@@ -31,7 +31,7 @@ inline void ReorderInitState(const DeviceContext& ctx,
                              const framework::Tensor& src,
                              framework::Vector<size_t> index_lod,
                              framework::Tensor* dst, bool indexed_src) {
-  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
+  phi::funcs::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
   dst->mutable_data<T>(src.dims(), ctx.GetPlace());
   row_shuffle(ctx, src, index_lod, dst, indexed_src);
 }
@@ -64,7 +64,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     cell_out->mutable_data<T>(ctx.GetPlace());
     bool is_reverse = ctx.Attr<bool>("is_reverse");

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& device_ctx = ctx.template device_context<DeviceContext>();
     to_batch(device_ctx, *input, batch_gate, true, is_reverse);
@@ -80,7 +80,7 @@ class LSTMKernel : public framework::OpKernel<T> {
       add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
     }

-    math::LstmMetaValue<T> lstm_value;
+    phi::funcs::LstmMetaValue<T> lstm_value;
     if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
       // the code style in LstmMetaValue will be updated later.
@@ -121,11 +121,11 @@ class LSTMKernel : public framework::OpKernel<T> {
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
-    auto gate_act = math::detail::GetActivationType(
+    auto gate_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("gate_activation"));
-    auto cell_act = math::detail::GetActivationType(
+    auto cell_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("cell_activation"));
-    auto cand_act = math::detail::GetActivationType(
+    auto cand_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));

     auto blas = phi::funcs::GetBlas<DeviceContext, T>(device_ctx);
@@ -166,13 +166,13 @@ class LSTMKernel : public framework::OpKernel<T> {
       lstm_value.state_value = cell_t.data<T>();
       lstm_value.state_active_value = cell_pre_act_t.data<T>();
       T cell_clip = 0.0;
-      math::LstmUnitFunctor<DeviceContext, T>::compute(
+      phi::funcs::LstmUnitFunctor<DeviceContext, T>::compute(
           device_ctx, lstm_value, frame_size, cur_batch_size, cell_clip,
           gate_act, cell_act, cand_act);
       lstm_value.prev_state_value = lstm_value.state_value;
     }

-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_hidden.set_lod(batch_gate->lod());
     // restore the output hidden in LoDTensor from the batch hidden
     to_seq(device_ctx, batch_hidden, hidden_out);
@@ -241,7 +241,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
                           ") should be %d, but received %d in LSTM@Grad operator.",
                           frame_size, out_dims[1]));

-    math::LstmMetaValue<T> lstm_value;
+    phi::funcs::LstmMetaValue<T> lstm_value;
     if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
       lstm_value.check_ig = bias_data + 4 * frame_size;
@@ -253,7 +253,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       lstm_value.check_og = nullptr;
     }

-    math::LstmMetaGrad<T> lstm_grad;
+    phi::funcs::LstmMetaGrad<T> lstm_grad;
     if (bias && bias_g) {
       bias_g->mutable_data<T>(ctx.GetPlace());
@@ -270,7 +270,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       lstm_grad.check_og_grad = nullptr;
     }

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto ToBatch = [&batch_gate, &to_batch](
         const DeviceContext& ctx, const framework::LoDTensor& src,
@@ -293,11 +293,11 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
     batch_gate_g.set_lod(batch_gate->lod());

-    auto gate_act = math::detail::GetActivationType(
+    auto gate_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("gate_activation"));
-    auto cell_act = math::detail::GetActivationType(
+    auto cell_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("cell_activation"));
-    auto cand_act = math::detail::GetActivationType(
+    auto cand_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));

     auto batch_starts = batch_gate->lod()[0];
@@ -338,7 +338,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       lstm_grad.state_active_grad = nullptr;
       int cur_batch_size = bend - bstart;
       T cell_clip = 0.0;
-      math::LstmUnitGradFunctor<DeviceContext, T>::compute(
+      phi::funcs::LstmUnitGradFunctor<DeviceContext, T>::compute(
           device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
           cell_clip, gate_act, cell_act, cand_act);
@@ -369,7 +369,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       }
     }

-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     if (in_g) {
       /* backward data */
       in_g->mutable_data<T>(ctx.GetPlace());
...
@@ -18,12 +18,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/lstm_compute.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
+#include "paddle/phi/kernels/funcs/lstm_compute.h"
+#include "paddle/phi/kernels/funcs/sequence2batch.h"

 namespace paddle {
 namespace operators {
@@ -72,7 +72,7 @@ inline void ReorderInitState(const DeviceContext& ctx,
                              const framework::Tensor& src,
                              framework::Vector<size_t> index,
                              framework::Tensor* dst, bool indexed_src) {
-  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
+  phi::funcs::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
   dst->mutable_data<T>(src.dims(), ctx.GetPlace());
   row_shuffle(ctx, src, index, dst, indexed_src);
 }
@@ -81,15 +81,15 @@ template <typename DeviceContext, typename T>
 class LSTMPKernel : public framework::OpKernel<T> {
  public:
  template <typename Device, typename X, typename Y>
-  void ActCompute(const math::detail::ActivationType act_type, const Device& d,
-                  X x, Y y, platform::Place place) const {
+  void ActCompute(const phi::funcs::detail::ActivationType act_type,
+                  const Device& d, X x, Y y, platform::Place place) const {
-    if (act_type == math::detail::ActivationType::kIdentity) {
+    if (act_type == phi::funcs::detail::ActivationType::kIdentity) {
       y.device(d) = x;
-    } else if (act_type == math::detail::ActivationType::kSigmoid) {
+    } else if (act_type == phi::funcs::detail::ActivationType::kSigmoid) {
       SigmoidFunctor<T>()(d, x, y);
-    } else if (act_type == math::detail::ActivationType::kTanh) {
+    } else if (act_type == phi::funcs::detail::ActivationType::kTanh) {
       TanhFunctor<T>()(d, x, y);
-    } else if (act_type == math::detail::ActivationType::kReLU) {
+    } else if (act_type == phi::funcs::detail::ActivationType::kReLU) {
       if (place == platform::CPUPlace())
         ReluCPUFunctor<T>()(d, x, y);
       else
@@ -120,7 +120,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
     cell_out->mutable_data<T>(ctx.GetPlace());
     bool is_reverse = ctx.Attr<bool>("is_reverse");

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& device_ctx = ctx.template device_context<DeviceContext>();
     to_batch(device_ctx, *input, batch_gate, true, is_reverse);
@@ -137,7 +137,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
       add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
     }

-    math::LstmMetaValue<T> lstmp_value;
+    phi::funcs::LstmMetaValue<T> lstmp_value;
     if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
       // the code style in LstmpMetaValue will be updated later.
@@ -176,13 +176,13 @@ class LSTMPKernel : public framework::OpKernel<T> {
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
-    auto gate_act = math::detail::GetActivationType(
+    auto gate_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("gate_activation"));
-    auto cell_act = math::detail::GetActivationType(
+    auto cell_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("cell_activation"));
-    auto cand_act = math::detail::GetActivationType(
+    auto cand_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));
-    auto proj_act = math::detail::GetActivationType(
+    auto proj_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("proj_activation"));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto blas = phi::funcs::GetBlas<DeviceContext, T>(device_ctx);
@@ -222,13 +222,13 @@ class LSTMPKernel : public framework::OpKernel<T> {
       lstmp_value.output_value = hidden_t.data<T>();
       lstmp_value.state_value = cell_t.data<T>();
       lstmp_value.state_active_value = cell_pre_act_t.data<T>();
-      math::LstmUnitFunctor<DeviceContext, T>::compute(
+      phi::funcs::LstmUnitFunctor<DeviceContext, T>::compute(
          device_ctx, lstmp_value, frame_size, cur_batch_size, cell_clip,
          gate_act, cell_act, cand_act);
       lstmp_value.prev_state_value = lstmp_value.state_value;
       blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0),
                   &proj_t, static_cast<T>(0.0));
-      if (proj_act != math::detail::ActivationType::kIdentity) {
+      if (proj_act != phi::funcs::detail::ActivationType::kIdentity) {
         auto proj_t_dev = EigenMatrix<T>::From(proj_t);
         ActCompute(cell_act, place, proj_t_dev, proj_t_dev, ctx.GetPlace());
       }
@@ -242,7 +242,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
       }
     }

-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_proj.set_lod(batch_gate->lod());
     // restore the output hidden in LoDTensor from the batch hidden
     to_seq(device_ctx, batch_proj, proj_out);
@@ -257,16 +257,16 @@ template <typename DeviceContext, typename T>
 class LSTMPGradKernel : public framework::OpKernel<T> {
  public:
  template <typename Device, typename X, typename Y, typename DX, typename DY>
-  void ActGradCompute(const math::detail::ActivationType act_type,
+  void ActGradCompute(const phi::funcs::detail::ActivationType act_type,
                       const Device& d, X x, Y y, DX dx, DY dy) const {
     // x is dummy and won't be used even in Relu(use y instead)
-    if (act_type == math::detail::ActivationType::kIdentity)
+    if (act_type == phi::funcs::detail::ActivationType::kIdentity)
       dx.device(d) = dy;
-    else if (act_type == math::detail::ActivationType::kSigmoid)
+    else if (act_type == phi::funcs::detail::ActivationType::kSigmoid)
       SigmoidGradFunctor<T>()(d, x, y, dy, dx);
-    else if (act_type == math::detail::ActivationType::kTanh)
+    else if (act_type == phi::funcs::detail::ActivationType::kTanh)
       TanhGradFunctor<T>()(d, x, y, dy, dx);
-    else if (act_type == math::detail::ActivationType::kReLU)
+    else if (act_type == phi::funcs::detail::ActivationType::kReLU)
       ReluGradFunctor<T>()(d, x, y, dy, dx);
     else
       PADDLE_THROW(
@@ -340,7 +340,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
                           "but received %d in LSTMP@Grad operator.",
                           frame_size, out_dims[1]));

-    math::LstmMetaValue<T> lstmp_value;
+    phi::funcs::LstmMetaValue<T> lstmp_value;
     if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
       lstmp_value.check_ig = bias_data + 4 * frame_size;
@@ -352,7 +352,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       lstmp_value.check_og = nullptr;
     }

-    math::LstmMetaGrad<T> lstmp_grad;
+    phi::funcs::LstmMetaGrad<T> lstmp_grad;
     if (bias && bias_g) {
       bias_g->mutable_data<T>(ctx.GetPlace());
@@ -369,7 +369,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       lstmp_grad.check_og_grad = nullptr;
     }

-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto ToBatch = [&batch_gate, &to_batch](
         const DeviceContext& ctx, const framework::LoDTensor& src,
@@ -393,13 +393,13 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
     batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
     batch_gate_g.set_lod(batch_gate->lod());

-    auto gate_act = math::detail::GetActivationType(
+    auto gate_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("gate_activation"));
-    auto cell_act = math::detail::GetActivationType(
+    auto cell_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("cell_activation"));
-    auto cand_act = math::detail::GetActivationType(
+    auto cand_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));
-    auto proj_act = math::detail::GetActivationType(
+    auto proj_act = phi::funcs::detail::GetActivationType(
         ctx.Attr<std::string>("proj_activation"));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
@@ -423,7 +423,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
                   _ClipGradFunctor<T>(-1.0 * proj_clip, proj_clip));
       }

-      if (proj_act != math::detail::ActivationType::kIdentity) {
+      if (proj_act != phi::funcs::detail::ActivationType::kIdentity) {
         auto cur_proj_dev = EigenMatrix<T>::From(cur_proj);
         auto proj_g_dev = EigenMatrix<T>::From(proj_g);
         ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev,
@@ -470,7 +470,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       lstmp_value.output_value = nullptr;
       lstmp_grad.state_active_grad = nullptr;

-      math::LstmUnitGradFunctor<DeviceContext, T>::compute(
+      phi::funcs::LstmUnitGradFunctor<DeviceContext, T>::compute(
          device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size,
          cell_clip, gate_act, cell_act, cand_act);
@@ -503,7 +503,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       }
     }

-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     if (in_g) {
       /* backward data */
       in_g->mutable_data<T>(ctx.GetPlace());
...
-add_subdirectory(detail)
 if (WITH_ASCEND_CL)
   cc_library(beam_search_npu SRCS beam_search_npu.cc DEPS npu_op_runner)
 endif()
@@ -18,8 +16,7 @@ math_library(im2col)
 math_library(sample_prob)
 math_library(sampler DEPS generator)
-math_library(gru_compute DEPS activation_functions math_function)
-math_library(lstm_compute DEPS activation_functions)
+# math_library(math_function DEPS blas dense_tensor tensor)
 math_library(maxouting)
 math_library(pooling)
@@ -29,7 26,6 @@ else()
   math_library(selected_rows_functor DEPS selected_rows_utils math_function blas)
 endif()

-math_library(sequence2batch)
 math_library(sequence_padding)
 math_library(sequence_pooling DEPS math_function jit_kernel_helper)
 math_library(sequence_scale)
...
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/detail/lstm_gpu_kernel.h"
#include "paddle/fluid/operators/math/detail/lstm_kernel.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act,
bool old_api_version = true) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, cell_clip, cand_act,
gate_act, cell_act);
}
};
template <class T>
struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act,
bool old_api_version = true) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, cell_clip, cand_act,
gate_act, cell_act);
}
};
template class LstmUnitFunctor<platform::CUDADeviceContext, float>;
template class LstmUnitFunctor<platform::CUDADeviceContext, double>;
template class LstmUnitGradFunctor<platform::CUDADeviceContext, float>;
template class LstmUnitGradFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
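For readers unfamiliar with the layout of the file above (the fluid-side GPU implementation being relocated): LstmUnitFunctor and LstmUnitGradFunctor are thin specializations for CUDADeviceContext that simply forward to the detail:: GPU kernels. A minimal, self-contained sketch of that compile-time dispatch pattern (toy types, not Paddle code):

// Illustrative sketch only -- toy types, not Paddle code. A functor template
// is specialized per device context; each specialization forwards to a
// device-specific detail implementation, selected at compile time.
#include <iostream>

struct CPUContext {};
struct GPUContext {};

namespace detail {
inline void cpu_lstm_forward() { std::cout << "cpu lstm forward\n"; }
inline void gpu_lstm_forward() { std::cout << "gpu lstm forward\n"; }
}  // namespace detail

template <typename Context, typename T>
struct LstmUnitFunctor;  // primary template intentionally undefined

template <typename T>
struct LstmUnitFunctor<CPUContext, T> {
  static void compute(const CPUContext&) { detail::cpu_lstm_forward(); }
};

template <typename T>
struct LstmUnitFunctor<GPUContext, T> {
  static void compute(const GPUContext&) { detail::gpu_lstm_forward(); }
};

int main() {
  LstmUnitFunctor<CPUContext, float>::compute(CPUContext{});
  LstmUnitFunctor<GPUContext, float>::compute(GPUContext{});
  return 0;
}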
...@@ -20,13 +20,13 @@ limitations under the License. */ ...@@ -20,13 +20,13 @@ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/unique_op.h" #include "paddle/fluid/operators/unique_op.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/gru_compute.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -100,7 +100,7 @@ struct Cell { ...@@ -100,7 +100,7 @@ struct Cell {
}; };
template <typename T, template <typename> class EigenActivationFunctor, template <typename T, template <typename> class EigenActivationFunctor,
math::detail::ActivationType act_type> phi::funcs::detail::ActivationType act_type>
struct SimpleRNNCell : Cell<T> { struct SimpleRNNCell : Cell<T> {
void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
const Tensor* weight_hh, const Tensor* init_h, const Tensor* weight_hh, const Tensor* init_h,
...@@ -148,7 +148,7 @@ struct GRUCell : Cell<T> { ...@@ -148,7 +148,7 @@ struct GRUCell : Cell<T> {
size_t frame_size = init_h->dims()[2]; size_t frame_size = init_h->dims()[2];
size_t batch_size = init_h->dims()[1]; size_t batch_size = init_h->dims()[1];
math::GRUMetaValue<T> gru_value; phi::funcs::GRUMetaValue<T> gru_value;
gru_value.gate_weight = weight_hh->data<T>(); gru_value.gate_weight = weight_hh->data<T>();
gru_value.state_weight = weight_hh->data<T>() + 2 * frame_size * frame_size; gru_value.state_weight = weight_hh->data<T>() + 2 * frame_size * frame_size;
gru_value.reset_bias = bias_hh->data<T>() + 2 * frame_size; gru_value.reset_bias = bias_hh->data<T>() + 2 * frame_size;
...@@ -158,10 +158,10 @@ struct GRUCell : Cell<T> { ...@@ -158,10 +158,10 @@ struct GRUCell : Cell<T> {
gru_value.output_value = output->data<T>(); gru_value.output_value = output->data<T>();
gru_value.prev_out_value = init_h->data<T>(); gru_value.prev_out_value = init_h->data<T>();
auto gate_act = math::detail::GetActivationType("sigmoid_v2"); auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto cand_act = math::detail::GetActivationType("tanh_v2"); auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
math::GRUUnitFunctorV2<platform::CPUDeviceContext, T>::compute( phi::funcs::GRUUnitFunctorV2<platform::CPUDeviceContext, T>::compute(
*device_ctx, gru_value, frame_size, batch_size, cand_act, gate_act); *device_ctx, gru_value, frame_size, batch_size, cand_act, gate_act);
} }
}; };
...@@ -184,14 +184,14 @@ struct LSTMCell : Cell<T> { ...@@ -184,14 +184,14 @@ struct LSTMCell : Cell<T> {
blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast<T>(1.0), blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast<T>(1.0),
input, static_cast<T>(1.0)); input, static_cast<T>(1.0));
math::LstmMetaValue<T> lstm_value; phi::funcs::LstmMetaValue<T> lstm_value;
lstm_value.check_ig = nullptr; lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr; lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr; lstm_value.check_og = nullptr;
auto gate_act = math::detail::GetActivationType("sigmoid_v2"); auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto cell_act = math::detail::GetActivationType("tanh_v2"); auto cell_act = phi::funcs::detail::GetActivationType("tanh_v2");
auto cand_act = math::detail::GetActivationType("tanh_v2"); auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
size_t frame_size = init_h->dims()[2]; size_t frame_size = init_h->dims()[2];
size_t batch_size = init_h->dims()[1]; size_t batch_size = init_h->dims()[1];
...@@ -208,7 +208,7 @@ struct LSTMCell : Cell<T> { ...@@ -208,7 +208,7 @@ struct LSTMCell : Cell<T> {
      lstm_value.state_value = last_c->data<T>();
      lstm_value.state_active_value = last_c_act->data<T>();
      T cell_clip = 0.0;
-     math::LstmUnitFunctor<platform::CPUDeviceContext, T>::compute(
+     phi::funcs::LstmUnitFunctor<platform::CPUDeviceContext, T>::compute(
          *device_ctx, lstm_value, frame_size, batch_size, cell_clip, gate_act,
          cell_act, cand_act, false);
    }
@@ -986,18 +986,18 @@ class RNNCPUKernel : public framework::OpKernel<T> {
          seed, reserve_data);
    } else if (is_rnn_relu(ctx)) {
      gate_num = 1;
-     RnnFunc<
-         SimpleRNNCell<T, ReluCPUFunctor, math::detail::ActivationType::kReLU>,
-         Layer, SingleLayer, BidirLayer, T>(
+     RnnFunc<SimpleRNNCell<T, ReluCPUFunctor,
+                           phi::funcs::detail::ActivationType::kReLU>,
+             Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
          state[0], nullptr, output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    } else if (is_rnn_tanh(ctx)) {
      gate_num = 1;
-     RnnFunc<
-         SimpleRNNCell<T, TanhFunctor, math::detail::ActivationType::kTanhV2>,
-         Layer, SingleLayer, BidirLayer, T>(
+     RnnFunc<SimpleRNNCell<T, TanhFunctor,
+                           phi::funcs::detail::ActivationType::kTanhV2>,
+             Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
          state[0], nullptr, output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
@@ -1014,14 +1014,14 @@ class RNNCPUKernel : public framework::OpKernel<T> {
  };

  template <typename T>
- void create_lstm_value(math::LstmMetaValue<T>* lstm_value) {
+ void create_lstm_value(phi::funcs::LstmMetaValue<T>* lstm_value) {
    lstm_value->check_ig = nullptr;
    lstm_value->check_fg = nullptr;
    lstm_value->check_og = nullptr;
  }

  template <typename T>
- void create_lstm_grad(math::LstmMetaGrad<T>* lstm_grad) {
+ void create_lstm_grad(phi::funcs::LstmMetaGrad<T>* lstm_grad) {
    lstm_grad->check_ig_grad = nullptr;
    lstm_grad->check_fg_grad = nullptr;
    lstm_grad->check_og_grad = nullptr;
@@ -1686,8 +1686,8 @@ struct GRUGradCell : GradCell<T> {
    // zero pre_hidden
    phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
    zero(device_ctx, grad_pre_hidden, static_cast<T>(0.0));
-   math::GRUMetaValue<T> gru_value;
-   math::GRUMetaGrad<T> gru_grad;
+   phi::funcs::GRUMetaValue<T> gru_value;
+   phi::funcs::GRUMetaGrad<T> gru_grad;
    gru_value.gate_value = gate_tensor->data<T>();
    gru_value.prev_out_value = pre_hidden->data<T>();
    gru_value.reset_output_value = state_tensor->data<T>();
@@ -1703,9 +1703,9 @@ struct GRUGradCell : GradCell<T> {
        grad_weight_hh->data<T>() + 2 * frame_size * frame_size;
    gru_grad.bias_hh_grad = grad_bias_hh->data<T>();
-   auto act_gate = math::detail::GetActivationType("sigmoid_v2");
-   auto act_node = math::detail::GetActivationType("tanh_v2");
-   math::GRUUnitGradFunctorV2<platform::CPUDeviceContext, T>::compute(
+   auto act_gate = phi::funcs::detail::GetActivationType("sigmoid_v2");
+   auto act_node = phi::funcs::detail::GetActivationType("tanh_v2");
+   phi::funcs::GRUUnitGradFunctorV2<platform::CPUDeviceContext, T>::compute(
        device_ctx, gru_value, gru_grad, frame_size, batch_size, act_node,
        act_gate);
@@ -1738,8 +1738,8 @@ struct LSTMGradCell : GradCell<T> {
      backup_tensor<T>(context, &grad_pre_state_bak, grad_pre_state);
    }
-   math::LstmMetaValue<T> lstm_value;
-   math::LstmMetaGrad<T> lstm_grad;
+   phi::funcs::LstmMetaValue<T> lstm_value;
+   phi::funcs::LstmMetaGrad<T> lstm_grad;
    create_lstm_value(&lstm_value);
    create_lstm_grad(&lstm_grad);
    lstm_value.gate_value = gate_tensor->data<T>();
@@ -1755,12 +1755,12 @@ struct LSTMGradCell : GradCell<T> {
    lstm_value.output_value = nullptr;
    lstm_grad.state_active_grad = nullptr;
-   auto gate_act = math::detail::GetActivationType("sigmoid_v2");
-   auto state_act = math::detail::GetActivationType("tanh_v2");
-   auto cand_act = math::detail::GetActivationType("tanh_v2");
+   auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
+   auto state_act = phi::funcs::detail::GetActivationType("tanh_v2");
+   auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
    T cell_clip = 0.0;
-   math::LstmUnitGradFunctor<platform::CPUDeviceContext, T>::compute(
+   phi::funcs::LstmUnitGradFunctor<platform::CPUDeviceContext, T>::compute(
        device_ctx, lstm_value, lstm_grad, frame_size, batch_size, cell_clip,
        gate_act, state_act, cand_act, false);
    this->update_pre_hidden_grad(
...
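The hunks above are mechanical: every functor that used to be reached through paddle::operators::math is now reached through phi::funcs (and its nested detail namespace), with call sites otherwise unchanged. A minimal self-contained sketch of the pattern, using illustrative names rather than Paddle's real headers:

// Sketch only: a functor template moves into a new nested namespace, and
// callers change nothing but the qualifier. Names here are hypothetical.
#include <iostream>

namespace phi {
namespace funcs {
template <typename T>
struct LstmUnitFunctorSketch {
  // Stand-in for LstmUnitFunctor<DeviceContext, T>::compute(...).
  static void compute(T cell_clip) {
    std::cout << "lstm unit, cell_clip=" << cell_clip << "\n";
  }
};
}  // namespace funcs
}  // namespace phi

int main() {
  // Previously spelled math::LstmUnitFunctorSketch<float>::compute(0.0f);
  phi::funcs::LstmUnitFunctorSketch<float>::compute(0.0f);
}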
add_subdirectory(eigen)
add_subdirectory(blas)
add_subdirectory(lapack)
+add_subdirectory(detail)

math_library(math_function DEPS blas dense_tensor tensor)
+math_library(sequence2batch)
+math_library(gru_compute DEPS activation_functions math_function)
+math_library(lstm_compute DEPS activation_functions)
math_library(concat_and_split_functor DEPS dense_tensor)
@@ -19,9 +19,8 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/core/hostdevice.h"

-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
namespace detail {

#define SIGMOID_THRESHOLD_MIN -40.0
@@ -132,25 +131,35 @@ struct Active {
#ifdef PADDLE_WITH_CUDA

static DEVICE Active<float>::Act kActFloat[] = {&forward::Sigmoid<float>,
                                                &forward::SigmoidV2<float>,
                                                &forward::Relu<float>,
                                                &forward::Tanh<float>,
                                                &forward::TanhV2<float>,
                                                &forward::Identity<float>};

static DEVICE Active<float>::ActGrad kActGradFloat[] = {
    &backward::Sigmoid<float>,
    &backward::Sigmoid<float>,
    &backward::Relu<float>,
    &backward::Tanh<float>,
    &backward::Tanh<float>,
    &backward::Identity<float>};

static DEVICE Active<double>::Act kActDouble[] = {&forward::Sigmoid<double>,
                                                  &forward::SigmoidV2<double>,
                                                  &forward::Relu<double>,
                                                  &forward::Tanh<double>,
                                                  &forward::TanhV2<double>,
                                                  &forward::Identity<double>};

static DEVICE Active<double>::ActGrad kActGradDouble[] = {
    &backward::Sigmoid<double>,
    &backward::Sigmoid<double>,
    &backward::Relu<double>,
    &backward::Tanh<double>,
    &backward::Tanh<double>,
    &backward::Identity<double>};

namespace forward {
inline DEVICE float activation(float a, int index) {
@@ -287,13 +296,19 @@ __m256 Identity(const __m256 a, const __m256 b);
}  // namespace avx
}  // namespace backward

static Active<__m256>::Act kActAvx[] = {&forward::avx::Sigmoid,
                                        &forward::avx::SigmoidV2,
                                        &forward::avx::Relu,
                                        &forward::avx::Tanh,
                                        &forward::avx::TanhV2,
                                        &forward::avx::Identity};

static Active<__m256>::ActGrad kActGradAvx[] = {&backward::avx::Sigmoid,
                                                &backward::avx::Sigmoid,
                                                &backward::avx::Relu,
                                                &backward::avx::Tanh,
                                                &backward::avx::Tanh,
                                                &backward::avx::Identity};

namespace forward {
inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); }
@@ -308,6 +323,5 @@ inline __m256 activation(__m256 a, __m256 b, int index) {
#endif

}  // namespace detail
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
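The kAct* tables above implement a small dispatch scheme: an ActivationType value indexes an array of function pointers, so activation(x, index) costs one indirect call and the same kernel code serves every activation. A self-contained model of that pattern (simplified; the real tables also carry the V2 variants and an identity slot):

// Model of the Active<T>::Act table: index with an enum, call through the slot.
#include <cmath>
#include <cstdio>

using Act = float (*)(float);

static float Sigmoid(float a) { return 1.0f / (1.0f + std::exp(-a)); }
static float Relu(float a) { return a > 0.0f ? a : 0.0f; }
static float Tanh(float a) { return std::tanh(a); }

enum ActIndex { kSigmoid = 0, kRelu = 1, kTanh = 2 };
static Act kActs[] = {&Sigmoid, &Relu, &Tanh};

inline float activation(float a, int index) { return kActs[index](a); }

int main() { std::printf("%f\n", activation(0.5f, kTanh)); }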
@@ -14,12 +14,11 @@ limitations under the License. */
#ifdef __AVX__

-#include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/detail/avx_mathfun.h"
+#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
+#include "paddle/phi/kernels/funcs/detail/avx_mathfun.h"

-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
namespace detail {

__m256 Exp(__m256 a) { return exp256_ps(a); }
@@ -77,8 +76,9 @@ namespace backward {
namespace avx {
__m256 Relu(const __m256 a, const __m256 b) {
  return _mm256_mul_ps(
      a,
      _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
                    _mm256_set1_ps(1.0f)));
}

__m256 Sigmoid(const __m256 a, const __m256 b) {
@@ -96,8 +96,7 @@ __m256 Identity(const __m256 a, const __m256 b) { return a; }
}  // namespace backward

}  // namespace detail
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi

#endif
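For reference, the reformatted backward::avx::Relu above is the vector form of a branch: _mm256_cmp_ps produces an all-ones mask where b > 0, and ANDing that mask with 1.0f yields a 0/1 multiplier. A scalar sketch of the same gradient:

// Scalar equivalent of the AVX ReLU backward pass shown above: the upstream
// gradient a passes through wherever the forward output b was positive.
#include <cstdio>

float relu_grad_scalar(float a, float b) {
  return a * (b > 0.0f ? 1.0f : 0.0f);
}

int main() { std::printf("%f %f\n", relu_grad_scalar(2.0f, 1.0f), relu_grad_scalar(2.0f, -1.0f)); }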
@@ -49,9 +49,9 @@ typedef __m256 v8sf;  // vector of 8 float (avx)
typedef __m256i v8si;  // vector of 8 int (avx)
typedef __m128i v4si;  // vector of 8 int (avx)

#define _PI32AVX_CONST(Name, Val) \
  static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = { \
      Val, Val, Val, Val}

_PI32AVX_CONST(1, 1);
_PI32AVX_CONST(inv1, ~1);
...
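The only change in this hunk is where the initializer list wraps; either form expands to the same aligned constant. A sketch of the expansion for _PI32AVX_CONST(1, 1), assuming ALIGN32_BEG/ALIGN32_END map to a 32-byte alignment attribute as they do in avx_mathfun.h:

// What the macro produces, modulo the exact alignment spelling:
alignas(32) static const int _pi32avx_1[4] = {1, 1, 1, 1};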
@@ -14,14 +14,13 @@ limitations under the License. */
#pragma once
#include <type_traits>
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
+#include "paddle/phi/kernels/funcs/gru_compute.h"

-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
namespace detail {

/*
@@ -30,9 +29,11 @@ namespace detail {
 */
template <class OpResetOutput, bool is_batch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                        T *gate_value,
                                        T *reset_output_value,
                                        const T *prev_output_value,
                                        int frame_size,
                                        int batch_size,
                                        ActivationType active_gate) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
@@ -55,8 +56,11 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
    r_prev_out = prev_output_value[frame_idx];
  }

  op_reset_output(&r_value_update_gate,
                  &r_value_reset_gate,
                  &r_prev_out,
                  &r_value_reset_output,
                  active_gate);

  gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
  gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
@@ -68,10 +72,14 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
 * grid(frame_blocks, batch_blocks)
 */
template <class OpFinalOutput, bool is_batch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                        T *gate_value,
                                        const T *prev_output_value,
                                        T *output_value,
                                        int frame_size,
                                        int batch_size,
                                        ActivationType active_node,
                                        bool origin_mode) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
@@ -92,8 +100,12 @@ __global__ void KeGruForwardFinalOutput(
    r_prev_out = prev_output_value[frame_idx];
  }

  op_final_output(&r_value_update_gate,
                  &r_value_frame_state,
                  &r_prev_out,
                  &r_output,
                  active_node,
                  origin_mode);

  gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
  output_value[frame_idx] = r_output;
@@ -106,7 +118,8 @@ __global__ void KeGruForwardFinalOutput(
template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruGate(T *gate_value,
                                        const T *prev_output_value,
                                        const T *gate_weight,
                                        T *reset_output,
                                        int frame_size,
                                        ActivationType active_node) {
  T xt_0 = 0.0f;
@@ -164,9 +177,12 @@ __global__ void KeFastCollectiveGruGate(T *gate_value,
 */
template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruOut(const T *gate_weight,
                                       const T *prev_out_value,
                                       T *output_value,
                                       T *gate_value,
                                       T *reset_value,
                                       int frame_size,
                                       ActivationType act_node,
                                       bool origin_mode) {
  int COL = blockIdx.x * blockDim.x + threadIdx.x;
@@ -221,10 +237,14 @@ __global__ void KeFastCollectiveGruOut(const T *gate_weight,
 * grid(frame_blocks, batch_blocks)
 */
template <class OpStateGrad, bool is_batch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad,
                                       T *gate_value,
                                       T *gate_grad,
                                       const T *prev_out_value,
                                       T *prev_out_grad,
                                       T *output_grad,
                                       int frame_size,
                                       int batch_size,
                                       ActivationType active_node,
                                       bool origin_mode) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -254,9 +274,15 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad,
    r_prev_out_grad = prev_out_grad[frame_idx];
  }
  op_state_grad(&r_update_gate_value,
                &r_update_gate_grad,
                &r_frame_state_value,
                &r_frame_state_grad,
                &r_prev_out_value,
                &r_prev_out_grad,
                &r_out_grad,
                active_node,
                origin_mode);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
@@ -270,10 +296,14 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad,
 * grid(frame_blocks, batch_blocks)
 */
template <class OpResetGrad, bool is_batch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad,
                                       T *gate_value,
                                       T *gate_grad,
                                       const T *prev_out_value,
                                       T *prev_out_grad,
                                       T *reset_output_grad,
                                       int frame_size,
                                       int batch_size,
                                       ActivationType active_gate) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
@@ -302,9 +332,14 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad,
    r_reset_output_grad = reset_output_grad[frame_idx];
  }
  op_reset_grad(&r_update_gate_value,
                &r_update_gate_grad,
                &r_reset_gate_value,
                &r_reset_gate_grad,
                &r_prev_out_value,
                &r_prev_out_grad,
                &r_reset_output_grad,
                active_gate);
  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
@@ -313,6 +348,5 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad,
  }
}

}  // namespace detail
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
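All of the kernels above share one addressing convention: threads in x cover frame_size (with an early return past the end), threads in y cover batch rows when is_batch is true, and each batch row of the gate buffer holds three gates of frame_size values each (update, reset, frame state), which is why the stores index gate_value[frame_idx + frame_size * 0/1/2]. A host-side sketch of that layout arithmetic (illustrative helper, not from the PR):

// gate buffer layout: [batch_size, 3 * frame_size]
// slot 0: update gate, slot 1: reset gate, slot 2: frame state
#include <cstddef>

float* gate_ptr(float* gate_value, int batch_idx, int slot, int frame_idx,
                int frame_size) {
  return gate_value + static_cast<std::size_t>(batch_idx) * 3 * frame_size +
         static_cast<std::size_t>(slot) * frame_size + frame_idx;
}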
@@ -14,13 +14,12 @@ limitations under the License. */
#pragma once
#include <type_traits>
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/kernels/funcs/detail/activation_functions.h"

// TODO(guosheng): refine code style in gru_kernel
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
namespace detail {

namespace forward {
@@ -28,8 +27,10 @@ namespace forward {
template <typename T>
class gru_resetOutput {
 public:
  HOSTDEVICE void operator()(T *value_update_gate,
                             T *value_reset_gate,
                             T *prev_out,
                             T *value_reset_output,
                             ActivationType act_gate,
                             T *value_reset_bias = nullptr,
                             bool old_version = true) {
@@ -48,7 +49,8 @@ class gru_resetOutput {
#else
  static const bool avx = true;
  HOSTDEVICE void operator()(__m256 *value_update_gate,
                             __m256 *value_reset_gate,
                             __m256 *prev_out,
                             __m256 *value_reset_output,
                             ActivationType act_gate,
                             __m256 *value_reset_bias = nullptr,
@@ -71,9 +73,12 @@ class gru_resetOutput {
template <typename T>
class gru_finalOutput {
 public:
  HOSTDEVICE void operator()(T *value_update_gate,
                             T *value_frame_state,
                             T *prev_out,
                             T *value_output,
                             ActivationType act_input,
                             bool origin_mode) {
    *value_frame_state = activation(*value_frame_state, act_input);
    if (origin_mode) {
      *value_output = ((*value_update_gate) * (*prev_out)) +
@@ -90,8 +95,10 @@ class gru_finalOutput {
#else
  static const bool avx = true;
  HOSTDEVICE void operator()(__m256 *value_update_gate,
                             __m256 *value_frame_state,
                             __m256 *prev_out,
                             __m256 *value_output,
                             ActivationType act_input,
                             bool origin_mode) {
    *value_frame_state = activation(*value_frame_state, act_input);
    if (origin_mode) {
@@ -116,10 +123,14 @@ namespace backward {
template <typename T>
class gru_stateGrad {
 public:
  HOSTDEVICE void operator()(T *value_update_gate,
                             T *grad_update_gate,
                             T *value_frame_state,
                             T *grad_frame_state,
                             T *value_prev_out,
                             T *grad_prev_out,
                             T *grad_output,
                             ActivationType act_input,
                             bool origin_mode) {
    if (origin_mode) {
      *grad_update_gate =
@@ -127,14 +138,15 @@ class gru_stateGrad {
      *grad_prev_out += (*grad_output * (*value_update_gate));
      *grad_frame_state = activation(
          *grad_output * (static_cast<T>(1.0) - (*value_update_gate)),
          *value_frame_state,
          act_input);
    } else {
      *grad_update_gate =
          (*grad_output) * ((*value_frame_state) - (*value_prev_out));
      *grad_prev_out +=
          (*grad_output * (static_cast<T>(1.0) - *value_update_gate));
      *grad_frame_state = activation(
          *grad_output * (*value_update_gate), *value_frame_state, act_input);
    }
  }
#if !defined(__NVCC__) && !defined(__HIPCC___)  // @{ Group GRU state grad
@@ -145,28 +157,35 @@ class gru_stateGrad {
  HOSTDEVICE void operator()(__m256 *value_update_gate,
                             __m256 *grad_update_gate,
                             __m256 *value_frame_state,
                             __m256 *grad_frame_state,
                             __m256 *value_prev_out,
                             __m256 *grad_prev_out,
                             __m256 *grad_output,
                             ActivationType act_input,
                             bool origin_mode) {
    if (origin_mode) {
      *grad_update_gate = _mm256_mul_ps(
          *grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state));
      *grad_prev_out = _mm256_add_ps(
          *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
      *grad_frame_state = activation(
          _mm256_mul_ps(
              *grad_output,
              _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate)),
          *value_frame_state,
          act_input);
    } else {
      *grad_update_gate = _mm256_mul_ps(
          *grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out));
      *grad_prev_out = _mm256_add_ps(
          *grad_prev_out,
          _mm256_mul_ps(
              *grad_output,
              _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate)));
      *grad_frame_state =
          activation(_mm256_mul_ps(*grad_output, *value_update_gate),
                     *value_frame_state,
                     act_input);
    }
  }
#endif
@@ -176,10 +195,14 @@ class gru_stateGrad {
template <typename T>
class gru_resetGrad {
 public:
  HOSTDEVICE void operator()(T *value_update_gate,
                             T *grad_update_gate,
                             T *value_reset_gate,
                             T *grad_reset_gate,
                             T *value_prev_out,
                             T *grad_prev_out,
                             T *grad_reset_output,
                             ActivationType act_gate) {
    *grad_reset_gate = (*grad_reset_output * (*value_prev_out));
    *grad_prev_out += (*grad_reset_output * (*value_reset_gate));
    *grad_update_gate =
@@ -193,9 +216,12 @@ class gru_resetGrad {
#else
  static const bool avx = true;
  HOSTDEVICE void operator()(__m256 *value_update_gate,
                             __m256 *grad_update_gate,
                             __m256 *value_reset_gate,
                             __m256 *grad_reset_gate,
                             __m256 *value_prev_out,
                             __m256 *grad_prev_out,
                             __m256 *grad_reset_output,
                             ActivationType act_gate) {
    *grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out);
    *grad_prev_out = _mm256_add_ps(
@@ -211,23 +237,31 @@ class gru_resetGrad {
template <typename T>
class gru {
 public:
  HOSTDEVICE void operator()(T *value_reset_gate,
                             T *grad_reset_gate,
                             T *value_update_gate,
                             T *grad_update_gate,
                             T *value_frame_state,
                             T *grad_frame_state,
                             T *value_prev_out,
                             T *grad_prev_out,
                             T *grad_output,
                             T *value_reset_output,
                             T *grad_reset_output,
                             ActivationType act_node,
                             ActivationType act_gate) {
    *grad_update_gate =
        activation((*grad_output) * ((*value_prev_out) - (*value_frame_state)),
                   (*value_update_gate),
                   act_gate);
    *grad_prev_out += (*grad_output * (*value_update_gate));
    *grad_frame_state =
        activation(*grad_output * (static_cast<T>(1.0) - (*value_update_gate)),
                   *value_frame_state,
                   act_node);
    T reset_output = (*value_reset_output) / (*value_reset_gate);
    *grad_reset_gate = activation(
        reset_output * (*grad_frame_state), *value_reset_gate, act_gate);
    *grad_reset_output = (*value_reset_gate) * (*grad_frame_state);
  }
#if !defined(__NVCC__) && !defined(__HIPCC___)  // @{ Group GRU CPU
@@ -235,29 +269,36 @@ class gru {
  static const bool avx = false;
#else
  static const bool avx = true;
  HOSTDEVICE void operator()(__m256 *value_reset_gate,
                             __m256 *grad_reset_gate,
                             __m256 *value_update_gate,
                             __m256 *grad_update_gate,
                             __m256 *value_frame_state,
                             __m256 *grad_frame_state,
                             __m256 *value_prev_out,
                             __m256 *grad_prev_out,
                             __m256 *grad_output,
                             __m256 *value_reset_output,
                             __m256 *grad_reset_output,
                             ActivationType act_node,
                             ActivationType act_gate) {
    *grad_update_gate = activation(
        _mm256_mul_ps(*grad_output,
                      _mm256_sub_ps(*value_prev_out, *value_frame_state)),
        *value_update_gate,
        act_gate);
    *grad_prev_out = _mm256_add_ps(
        *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
    *grad_frame_state = activation(
        _mm256_mul_ps(*grad_output,
                      _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate)),
        *value_frame_state,
        act_node);
    __m256 reset_output = _mm256_div_ps(*value_reset_output, *value_reset_gate);
    *grad_reset_gate =
        activation(_mm256_mul_ps(reset_output, *grad_frame_state),
                   *value_reset_gate,
                   act_gate);
    *grad_reset_output = _mm256_mul_ps(*value_reset_gate, *grad_frame_state);
  }
#endif
@@ -267,6 +308,5 @@ class gru {
}  // namespace backward

}  // namespace detail
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
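Written out, the scalar branch of gru_stateGrad above is the exact derivative of the GRU output mix. With update gate $u_t$, candidate $\tilde{h}_t$, previous hidden state $h_{t-1}$, and upstream gradient $g = \partial L / \partial h_t$, the default (non-origin) mode differentiates

h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t

which gives

\frac{\partial L}{\partial u_t} = g \odot (\tilde{h}_t - h_{t-1}), \qquad
\frac{\partial L}{\partial h_{t-1}} \mathrel{+}= g \odot (1 - u_t), \qquad
\frac{\partial L}{\partial \tilde{h}_t} = g \odot u_t

where the last factor is then pushed through the candidate activation's derivative by activation(..., act_input). origin_mode swaps the roles of $h_{t-1}$ and $\tilde{h}_t$, i.e. it differentiates $h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t$, which matches the sign flips in the first branch.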
@@ -16,8 +16,8 @@ limitations under the License. */
#include <type_traits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/activation_op.h"
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/lstm_compute.h"
+#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
+#include "paddle/phi/kernels/funcs/lstm_compute.h"

#if defined(_WIN32)
#if defined(__AVX2__) || defined(__AVX__)
@@ -25,21 +25,23 @@ inline __m256 operator+=(__m256 a, __m256 b) { return _mm256_add_ps(a, b); }
#endif
#endif

-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
namespace detail {

using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+using EigenVector = paddle::framework::EigenVector<T, MajorType, IndexType>;

#if !defined(__NVCC__) && !defined(__HIPCC___)  // @{ Group LSTM CPU

template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op,
                                     phi::funcs::LstmMetaValue<T> value,
                                     int frame_size,
                                     T cell_clip,
                                     ActivationType active_node,
                                     ActivationType active_gate,
                                     ActivationType active_state,
@@ -79,9 +81,21 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
      r_prev_state = value.prev_state_value[i];
    }

    op(&r_value_in,
       &r_value_ig,
       &r_value_fg,
       &r_value_og,
       &r_prev_state,
       &r_state,
       &r_state_atv,
       &r_out,
       &r_checkI,
       &r_checkF,
       &r_checkO,
       &cell_clip,
       active_node,
       active_gate,
       active_state);

    value_in[i] = r_value_in;
    value_ig[i] = r_value_ig;
@@ -94,9 +108,12 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
}

template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op,
                                      phi::funcs::LstmMetaValue<T> value,
                                      phi::funcs::LstmMetaGrad<T> grad,
                                      int frame_size,
                                      T cell_clip,
                                      ActivationType active_node,
                                      ActivationType active_gate,
                                      ActivationType active_state,
                                      bool old_api_version) {
@@ -157,11 +174,30 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
      r_prev_state = value.prev_state_value[i];
    }

    op(&r_value_in,
       &r_value_ig,
       &r_value_fg,
       &r_value_og,
       &r_grad_in,
       &r_grad_ig,
       &r_grad_fg,
       &r_grad_og,
       &r_prev_state,
       &r_prev_state_grad,
       &r_state,
       &r_state_grad,
       &r_state_atv,
       &r_output_grad,
       &r_checkI,
       &r_checkF,
       &r_checkO,
       &r_checkIGrad,
       &r_checkFGrad,
       &r_checkOGrad,
       &cell_clip,
       active_node,
       active_gate,
       active_state);

    grad_in[i] = r_grad_in;
    grad_ig[i] = r_grad_ig;
@@ -179,8 +215,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
}

template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op,
                                   phi::funcs::LstmMetaValue<T> value,
                                   int frame_size,
                                   T cell_clip,
                                   ActivationType active_node,
                                   ActivationType active_gate,
                                   ActivationType active_state,
@@ -226,9 +264,21 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
          (reinterpret_cast<__m256 const *>(value.prev_state_value))[i];
    }

    op(&r_value_in,
       &r_value_ig,
       &r_value_fg,
       &r_value_og,
       &r_prev_state,
       &r_state,
       &r_state_atv,
       &r_out,
       &r_checkI,
       &r_checkF,
       &r_checkO,
       &cell_clip,
       active_node,
       active_gate,
       active_state);

    value_in[i] = r_value_in;
    value_ig[i] = r_value_ig;
@@ -242,9 +292,12 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
}

template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op,
                                    phi::funcs::LstmMetaValue<T> value,
                                    phi::funcs::LstmMetaGrad<T> grad,
                                    int frame_size,
                                    T cell_clip,
                                    ActivationType active_node,
                                    ActivationType active_gate,
                                    ActivationType active_state,
                                    bool old_api_version) {
@@ -311,11 +364,30 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
          (reinterpret_cast<__m256 const *>(value.prev_state_value))[i];
    }

    op(&r_value_in,
       &r_value_ig,
       &r_value_fg,
       &r_value_og,
       &r_grad_in,
       &r_grad_ig,
       &r_grad_fg,
       &r_grad_og,
       &r_prev_state,
       &r_prev_state_grad,
       &r_state,
       &r_state_grad,
       &r_state_atv,
       &r_output_grad,
       &r_checkI,
       &r_checkF,
       &r_checkO,
       &r_checkIGrad,
       &r_checkFGrad,
       &r_checkOGrad,
       &cell_clip,
       active_node,
       active_gate,
       active_state);

    grad_in[i] = r_grad_in;
    grad_ig[i] = r_grad_ig;
@@ -338,8 +410,10 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
}

template <class T>
void eigen_lstm_forward_one_sequence(
    const paddle::platform::CPUDeviceContext &context,
    phi::funcs::LstmMetaValue<T> value,
    int frame_size) {
  auto eigen_value_ig =
      typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
  auto eigen_value_fg = typename EigenVector<T>::Type(
@@ -356,10 +430,10 @@ void eigen_lstm_forward_one_sequence(const platform::CPUDeviceContext &context,
      typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
  auto &place = *context.eigen_device();

  paddle::operators::TanhFunctor<T>()(place, eigen_value_in, eigen_value_in);
  paddle::operators::SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig);
  paddle::operators::SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg);
  paddle::operators::SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og);

  eigen_state.device(place) = eigen_value_in * eigen_value_ig;
  if (value.prev_state_value) {
@@ -368,14 +442,16 @@ void eigen_lstm_forward_one_sequence(const platform::CPUDeviceContext &context,
    eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg;
  }

  paddle::operators::TanhFunctor<T>()(place, eigen_state, eigen_state_act);
  eigen_output.device(place) = eigen_value_og * eigen_state_act;
}

template <class T>
void eigen_lstm_backward_one_sequence(
    const paddle::platform::CPUDeviceContext &context,
    phi::funcs::LstmMetaValue<T> value,
    phi::funcs::LstmMetaGrad<T> grad,
    int frame_size) {
  auto eigen_value_ig =
      typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
  auto eigen_value_fg = typename EigenVector<T>::Type(
@@ -401,23 +477,38 @@ void eigen_lstm_backward_one_sequence(const platform::CPUDeviceContext &context,
      typename EigenVector<T>::Type(grad.state_grad, Array1(frame_size));
  auto &place = *context.eigen_device();

  paddle::operators::SigmoidGradFunctor<T>()(
      place,
      1 /*useless*/,
      eigen_value_og,
      eigen_grad_output * eigen_state_act,
      eigen_grad_og);

  eigen_grad_state.device(place) =
      eigen_grad_state +
      eigen_grad_output * eigen_value_og *
          (static_cast<T>(1) - eigen_state_act * eigen_state_act);

  paddle::operators::TanhGradFunctor<T>()(place,
                                          1,
                                          eigen_value_in,
                                          eigen_grad_state * eigen_value_ig,
                                          eigen_grad_in);
  paddle::operators::SigmoidGradFunctor<T>()(place,
                                             1,
                                             eigen_value_ig,
                                             eigen_grad_state * eigen_value_in,
                                             eigen_grad_ig);
  if (value.prev_state_value) {
    auto eigen_prev_state = typename EigenVector<T>::ConstType(
        value.prev_state_value, Array1(frame_size));
    paddle::operators::SigmoidGradFunctor<T>()(
        place,
        1,
        eigen_value_fg,
        eigen_grad_state * eigen_prev_state,
        eigen_grad_fg);
  } else {
    paddle::operators::SigmoidGradFunctor<T>()(
        place, 1, eigen_value_fg, 0, eigen_grad_fg);
  }
  if (grad.prev_state_grad) {
    auto eigen_grad_pre_state =
@@ -427,42 +518,74 @@ void eigen_lstm_backward_one_sequence(const platform::CPUDeviceContext &context,
}

template <class T, class Op>
void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context,
                      Op op,
                      phi::funcs::LstmMetaValue<T> value,
                      int frame_size,
                      T cell_clip,
                      ActivationType active_node,
                      ActivationType active_gate,
                      ActivationType active_state,
                      bool old_api_version) {
  if (!old_api_version) {
    eigen_lstm_forward_one_sequence<T>(context, value, frame_size);
  } else {
    if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
      avx_lstm_forward_one_sequence<T>(op,
                                       value,
                                       frame_size,
                                       cell_clip,
                                       active_node,
                                       active_gate,
                                       active_state,
                                       old_api_version);
    } else {
      naive_lstm_forward_one_sequence<T>(op,
                                         value,
                                         frame_size,
                                         cell_clip,
                                         active_node,
                                         active_gate,
                                         active_state,
                                         old_api_version);
    }
  }
}

template <class T, class Op>
void cpu_lstm_backward(const paddle::platform::CPUDeviceContext &context,
                       Op op,
                       phi::funcs::LstmMetaValue<T> value,
                       phi::funcs::LstmMetaGrad<T> grad,
                       int frame_size,
                       T cell_clip,
                       ActivationType active_node,
                       ActivationType active_gate,
                       ActivationType active_state,
                       bool old_api_version) {
  if (!old_api_version) {
    eigen_lstm_backward_one_sequence<T>(context, value, grad, frame_size);
  } else {
    if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
      avx_lstm_backward_one_sequence<T>(op,
                                        value,
                                        grad,
                                        frame_size,
                                        cell_clip,
                                        active_node,
                                        active_gate,
                                        active_state,
                                        old_api_version);
    } else {
      naive_lstm_backward_one_sequence<T>(op,
                                          value,
                                          grad,
                                          frame_size,
                                          cell_clip,
                                          active_node,
                                          active_gate,
                                          active_state,
                                          old_api_version);
    }
  }
}
@@ -470,6 +593,5 @@ void cpu_lstm_backward(const platform::CPUDeviceContext &context, Op op,
#endif  // @{ End Group LSTM CPU

}  // namespace detail
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
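cpu_lstm_forward and cpu_lstm_backward take the AVX path only when the op provides an AVX overload, T is float, and frame_size is a multiple of 8 (one __m256 carries eight floats); !(frame_size & (8 - 1)) is the branch-free multiple-of-8 test. The predicate in isolation:

// Mirrors the dispatch condition used above (sketch, not Paddle's API).
#include <type_traits>

template <class T>
bool use_avx_path(bool op_supports_avx, int frame_size) {
  return op_supports_avx && !(frame_size & (8 - 1)) &&
         std::is_same<T, float>::value;
}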
...@@ -15,14 +15,13 @@ limitations under the License. */ ...@@ -15,14 +15,13 @@ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
namespace detail { namespace detail {
/* /*
...@@ -30,8 +29,11 @@ namespace detail { ...@@ -30,8 +29,11 @@ namespace detail {
* grid(frame_blocks, batch_blocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class T, class Op, bool is_batch> template <class T, class Op, bool is_batch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, __global__ void KeLstmForward(Op op,
int batch_size, T cell_clip, phi::funcs::LstmMetaValue<T> value,
int frame_size,
int batch_size,
T cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
...@@ -71,9 +73,21 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, ...@@ -71,9 +73,21 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
r_prev_state = value.prev_state_value[frame_idx]; r_prev_state = value.prev_state_value[frame_idx];
} }
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, op(&r_value_in,
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, &r_value_ig,
&cell_clip, active_node, active_gate, active_state); &r_value_fg,
&r_value_og,
&r_prev_state,
&r_state,
&r_state_atv,
&r_out,
&r_checkI,
&r_checkF,
&r_checkO,
&cell_clip,
active_node,
active_gate,
active_state);
value.gate_value[frame_idx] = r_value_in; value.gate_value[frame_idx] = r_value_in;
value.gate_value[frame_idx + frame_size] = r_value_ig; value.gate_value[frame_idx + frame_size] = r_value_ig;
...@@ -90,9 +104,12 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, ...@@ -90,9 +104,12 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
* grid(frame_blocks, batch_blocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class T, class Op, bool is_batch> template <class T, class Op, bool is_batch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, __global__ void KeLstmBackward(Op op,
LstmMetaGrad<T> grad, int frame_size, phi::funcs::LstmMetaValue<T> value,
int batch_size, T cell_clip, phi::funcs::LstmMetaGrad<T> grad,
int frame_size,
int batch_size,
T cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
...@@ -147,11 +164,30 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, ...@@ -147,11 +164,30 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
r_prev_state = value.prev_state_value[frame_idx]; r_prev_state = value.prev_state_value[frame_idx];
} }
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig, op(&r_value_in,
&r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, &r_value_ig,
&r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, &r_value_fg,
&r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, &cell_clip, &r_value_og,
active_node, active_gate, active_state); &r_grad_in,
&r_grad_ig,
&r_grad_fg,
&r_grad_og,
&r_prev_state,
&r_prev_state_grad,
&r_state,
&r_state_grad,
&r_state_atv,
&r_output_grad,
&r_checkI,
&r_checkF,
&r_checkO,
&r_checkIGrad,
&r_checkFGrad,
&r_checkOGrad,
&cell_clip,
active_node,
active_gate,
active_state);
grad.gate_grad[frame_idx] = r_grad_in; grad.gate_grad[frame_idx] = r_grad_in;
grad.gate_grad[frame_idx + frame_size] = r_grad_ig; grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
...@@ -185,10 +221,15 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, ...@@ -185,10 +221,15 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
} }
template <class T, class Op> template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op, void gpu_lstm_forward(const paddle::platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size, Op op,
T cell_clip, ActivationType active_node, phi::funcs::LstmMetaValue<T> value,
ActivationType active_gate, ActivationType active_state) { int frame_size,
int batch_size,
T cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batch_size == 1) { if (batch_size == 1) {
...@@ -203,25 +244,45 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, ...@@ -203,25 +244,45 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
} }
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream(); reinterpret_cast<const paddle::platform::CUDADeviceContext&>(context)
.stream();
if (batch_size == 1) { if (batch_size == 1) {
KeLstmForward<T, Op, KeLstmForward<T,
Op,
/* is_batch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, frame_size, batch_size, cell_clip, active_node, active_gate, op,
value,
frame_size,
batch_size,
cell_clip,
active_node,
active_gate,
active_state); active_state);
} else { } else {
KeLstmForward<T, Op, KeLstmForward<T,
Op,
/* is_batch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, frame_size, batch_size, cell_clip, active_node, active_gate, op,
value,
frame_size,
batch_size,
cell_clip,
active_node,
active_gate,
active_state); active_state);
} }
} }
template <class T, class Op> template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op, void gpu_lstm_backward(const paddle::platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, Op op,
int frame_size, int batch_size, T cell_clip, phi::funcs::LstmMetaValue<T> value,
ActivationType active_node, ActivationType active_gate, phi::funcs::LstmMetaGrad<T> grad,
int frame_size,
int batch_size,
T cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
...@@ -237,21 +298,37 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, ...@@ -237,21 +298,37 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
} }
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream(); reinterpret_cast<const paddle::platform::CUDADeviceContext&>(context)
.stream();
if (batch_size == 1) { if (batch_size == 1) {
KeLstmBackward<T, Op, KeLstmBackward<T,
Op,
/* is_batch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frame_size, batch_size, cell_clip, active_node, op,
active_gate, active_state); value,
grad,
frame_size,
batch_size,
cell_clip,
active_node,
active_gate,
active_state);
} else { } else {
KeLstmBackward<T, Op, KeLstmBackward<T,
Op,
/* is_batch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frame_size, batch_size, cell_clip, active_node, op,
active_gate, active_state); value,
grad,
frame_size,
batch_size,
cell_clip,
active_node,
active_gate,
active_state);
} }
} }
} // namespace detail } // namespace detail
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
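Note on the launchers above: the batch-size branch is baked into KeLstmForward/KeLstmBackward as a bool template parameter, so the single-sequence instantiation compiles with no batch indexing at all. A minimal standalone sketch of that dispatch pattern (names here are illustrative, not from the diff):

#include <cstdio>

// is_batch is a compile-time flag, as in KeLstmForward/KeLstmBackward above;
// the untaken branch is dead code in each instantiation.
template <bool is_batch>
void step(int frame_idx, int batch_idx) {
  if (is_batch) {
    std::printf("frame %d of batch row %d\n", frame_idx, batch_idx);
  } else {
    std::printf("frame %d (single sequence)\n", frame_idx);
  }
}

int main() {
  int batch_size = 1;
  if (batch_size == 1) {
    step</* is_batch= */ false>(0, 0);  // mirrors the batch_size == 1 launch
  } else {
    step</* is_batch= */ true>(0, 0);
  }
  return 0;
}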
...@@ -14,12 +14,11 @@ limitations under the License. */ ...@@ -14,12 +14,11 @@ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
namespace detail { namespace detail {
namespace forward { namespace forward {
...@@ -27,9 +26,18 @@ namespace forward { ...@@ -27,9 +26,18 @@ namespace forward {
template <class T> template <class T>
class lstm { class lstm {
public: public:
HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og, HOSTDEVICE void operator()(T *value_in,
T *prev_state, T *state, T *state_atv, T *output, T *value_ig,
T *checkI, T *checkF, T *checkO, T *cell_clip, T *value_fg,
T *value_og,
T *prev_state,
T *state,
T *state_atv,
T *output,
T *checkI,
T *checkF,
T *checkO,
T *cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
...@@ -57,11 +65,18 @@ class lstm { ...@@ -57,11 +65,18 @@ class lstm {
// Only float supports AVX optimization // Only float supports AVX optimization
static const bool avx = std::is_same<T, float>::value; static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 *value_in, __m256 *value_ig, HOSTDEVICE void operator()(__m256 *value_in,
__m256 *value_fg, __m256 *value_og, __m256 *value_ig,
__m256 *prev_state, __m256 *state, __m256 *value_fg,
__m256 *state_atv, __m256 *output, __m256 *checkI, __m256 *value_og,
__m256 *checkF, __m256 *checkO, T *cell_clip, __m256 *prev_state,
__m256 *state,
__m256 *state_atv,
__m256 *output,
__m256 *checkI,
__m256 *checkF,
__m256 *checkO,
T *cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
...@@ -97,12 +112,27 @@ namespace backward { ...@@ -97,12 +112,27 @@ namespace backward {
template <class T> template <class T>
class lstm { class lstm {
public: public:
HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og, HOSTDEVICE void operator()(T *value_in,
T *grad_in, T *grad_ig, T *grad_fg, T *grad_og, T *value_ig,
T *prev_state, T *prev_state_grad, T *state, T *value_fg,
T *state_grad, T *state_atv, T *output_grad, T *value_og,
T *checkI, T *checkF, T *checkO, T *checkIGrad, T *grad_in,
T *checkFGrad, T *checkOGrad, T *cell_clip, T *grad_ig,
T *grad_fg,
T *grad_og,
T *prev_state,
T *prev_state_grad,
T *state,
T *state_grad,
T *state_atv,
T *output_grad,
T *checkI,
T *checkF,
T *checkO,
T *checkIGrad,
T *checkFGrad,
T *checkOGrad,
T *cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
...@@ -138,17 +168,32 @@ class lstm { ...@@ -138,17 +168,32 @@ class lstm {
#else #else
// Only float supports AVX optimization // Only float supports AVX optimization
static const bool avx = std::is_same<T, float>::value; static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()( HOSTDEVICE void operator()(__m256 *value_in,
__m256 *value_in, __m256 *value_ig, __m256 *value_fg, __m256 *value_og, __m256 *value_ig,
__m256 *grad_in, __m256 *grad_ig, __m256 *grad_fg, __m256 *grad_og, __m256 *value_fg,
__m256 *prev_state, __m256 *prev_state_grad, __m256 *state, __m256 *value_og,
__m256 *state_grad, __m256 *state_atv, __m256 *output_grad, __m256 *grad_in,
__m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad, __m256 *grad_ig,
__m256 *checkFGrad, __m256 *checkOGrad, T *cell_clip, __m256 *grad_fg,
ActivationType active_node, ActivationType active_gate, __m256 *grad_og,
ActivationType active_state) { __m256 *prev_state,
*grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og, __m256 *prev_state_grad,
active_gate); __m256 *state,
__m256 *state_grad,
__m256 *state_atv,
__m256 *output_grad,
__m256 *checkI,
__m256 *checkF,
__m256 *checkO,
__m256 *checkIGrad,
__m256 *checkFGrad,
__m256 *checkOGrad,
T *cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
*grad_og = activation(
_mm256_mul_ps(*output_grad, *state_atv), *value_og, active_gate);
if (*cell_clip > 0.0f) { if (*cell_clip > 0.0f) {
T *state_ = reinterpret_cast<T *>(state); T *state_ = reinterpret_cast<T *>(state);
if (*state_ >= (*cell_clip) || *state_ <= (0.0f - (*cell_clip))) { if (*state_ >= (*cell_clip) || *state_ <= (0.0f - (*cell_clip))) {
...@@ -156,18 +201,19 @@ class lstm { ...@@ -156,18 +201,19 @@ class lstm {
} else { } else {
*state_grad = *state_grad =
_mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og), _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
*state_atv, active_state), *state_atv,
active_state),
*state_grad); *state_grad);
*state_grad = *state_grad =
_mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad); _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
} }
} }
*grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in, *grad_in = activation(
active_node); _mm256_mul_ps(*state_grad, *value_ig), *value_in, active_node);
*grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig, *grad_ig = activation(
active_gate); _mm256_mul_ps(*state_grad, *value_in), *value_ig, active_gate);
*grad_fg = activation(_mm256_mul_ps(*state_grad, *prev_state), *value_fg, *grad_fg = activation(
active_gate); _mm256_mul_ps(*state_grad, *prev_state), *value_fg, active_gate);
*prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI), *prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI),
_mm256_mul_ps(*grad_fg, *checkF)); _mm256_mul_ps(*grad_fg, *checkF));
*prev_state_grad = *prev_state_grad =
...@@ -183,6 +229,5 @@ class lstm { ...@@ -183,6 +229,5 @@ class lstm {
} // namespace backward } // namespace backward
} // namespace detail } // namespace detail
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
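Reading the AVX backward body above scalar-wise may help: each activation(g, y, act) call routes a gradient g through an activation whose saved output is y. A hedged scalar sketch, ignoring the cell-clip saturation branch (dact is a stand-in for that helper, not code from the diff):

// Scalar reading of the AVX lstm backward body above (sketch only).
// dact(g, y) plays the role of activation(g, y, act): g scaled by the
// activation derivative reconstructed from the saved output y.
template <class T, class DAct>
void lstm_backward_scalar(T value_in, T value_ig, T value_fg, T value_og,
                          T prev_state, T state_atv, T output_grad,
                          T checkI, T checkF, T checkO, DAct dact,
                          T* grad_in, T* grad_ig, T* grad_fg, T* grad_og,
                          T* state_grad, T* prev_state_grad) {
  *grad_og = dact(output_grad * state_atv, value_og);        // output gate
  *state_grad += dact(output_grad * value_og, state_atv)     // through o * act(c)
                 + *grad_og * checkO;                        // peephole o
  *grad_in = dact(*state_grad * value_ig, value_in);         // candidate input
  *grad_ig = dact(*state_grad * value_in, value_ig);         // input gate
  *grad_fg = dact(*state_grad * prev_state, value_fg);       // forget gate
  *prev_state_grad = *grad_ig * checkI + *grad_fg * checkF;  // peepholes i, f
  // (the diff elides the remaining accumulation into prev_state_grad)
}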
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/funcs/gru_compute.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h"
#include "paddle/phi/kernels/funcs/detail/gru_kernel.h"
namespace phi {
namespace funcs {
template <typename T>
struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, T> {
static void compute(const paddle::platform::CPUDeviceContext &context,
GRUMetaValue<T> value,
int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate,
bool origin_mode) {
#if !defined(__NVCC__) && !defined(__HIPCC__)
auto blas =
phi::funcs::GetBlas<paddle::platform::CPUDeviceContext, T>(context);
if (value.prev_out_value) {
blas.GEMM(false,
false,
batch_size,
frame_size * 2,
frame_size,
1,
value.prev_out_value,
frame_size,
value.gate_weight,
frame_size * 2,
1,
value.gate_value,
frame_size * 3);
}
detail::forward_reset_output(
phi::funcs::detail::forward::gru_resetOutput<T>(),
value,
frame_size,
batch_size,
active_gate,
true,
nullptr);
if (value.prev_out_value) {
blas.GEMM(false,
false,
batch_size,
frame_size,
frame_size,
1,
value.reset_output_value,
frame_size,
value.state_weight,
frame_size,
1,
value.gate_value + frame_size * 2,
frame_size * 3);
}
detail::forward_final_output(
phi::funcs::detail::forward::gru_finalOutput<T>(),
value,
frame_size,
batch_size,
active_node,
origin_mode,
true,
nullptr);
#endif
}
};
template <typename T>
struct GRUUnitGradFunctor<paddle::platform::CPUDeviceContext, T> {
static void compute(const paddle::platform::CPUDeviceContext &context,
GRUMetaValue<T> value,
GRUMetaGrad<T> grad,
int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate,
bool origin_mode) {
#if !defined(__NVCC__) && !defined(__HIPCC__)
detail::backward_state_grad(
phi::funcs::detail::backward::gru_stateGrad<T>(),
value,
grad,
frame_size,
batch_size,
active_node,
origin_mode);
auto blas =
phi::funcs::GetBlas<paddle::platform::CPUDeviceContext, T>(context);
if (value.prev_out_value && grad.prev_out_grad) {
blas.GEMM(false,
true,
batch_size,
frame_size,
frame_size,
1,
grad.gate_grad + frame_size * 2,
frame_size * 3,
value.state_weight,
frame_size,
0,
grad.reset_output_grad,
frame_size);
if (grad.state_weight_grad) {
blas.GEMM(true,
false,
frame_size,
frame_size,
batch_size,
1,
value.reset_output_value,
frame_size,
grad.gate_grad + frame_size * 2,
frame_size * 3,
1,
grad.state_weight_grad,
frame_size);
}
}
detail::backward_reset_grad(
phi::funcs::detail::backward::gru_resetGrad<T>(),
value,
grad,
frame_size,
batch_size,
active_gate);
if (grad.prev_out_grad && value.prev_out_value) {
blas.GEMM(false,
true,
batch_size,
frame_size,
frame_size * 2,
1,
grad.gate_grad,
frame_size * 3,
value.gate_weight,
frame_size * 2,
1,
grad.prev_out_grad,
frame_size);
if (grad.gate_weight_grad) {
blas.GEMM(true,
false,
frame_size,
frame_size * 2,
batch_size,
1,
value.prev_out_value,
frame_size,
grad.gate_grad,
frame_size * 3,
1,
grad.gate_weight_grad,
frame_size * 2);
}
}
#endif
}
};
template <typename T>
struct GRUUnitFunctorV2<paddle::platform::CPUDeviceContext, T> {
static void compute(const paddle::platform::CPUDeviceContext &context,
GRUMetaValue<T> value,
int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate) {
#if !defined(__NVCC__) && !defined(__HIPCC__)
auto blas =
phi::funcs::GetBlas<paddle::platform::CPUDeviceContext, T>(context);
if (value.prev_out_value) {
blas.GEMM(CblasNoTrans,
CblasTrans,
batch_size,
frame_size,
frame_size,
1,
value.prev_out_value,
value.state_weight,
0,
value.reset_output_value);
}
detail::forward_reset_output(
phi::funcs::detail::forward::gru_resetOutput<T>(),
value,
frame_size,
batch_size,
active_gate,
false,
&context);
T *cell_state_value = value.gate_value + 2 * frame_size;
T *reset_output_value = value.reset_output_value;
for (int b = 0; b < batch_size; ++b) {
blas.VADD(
frame_size, cell_state_value, reset_output_value, cell_state_value);
cell_state_value += frame_size * 3;
reset_output_value += frame_size;
}
detail::forward_final_output(
phi::funcs::detail::forward::gru_finalOutput<T>(),
value,
frame_size,
batch_size,
active_node,
true,
false,
&context);
#endif
}
};
template <typename T>
struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, T> {
static void compute(const paddle::platform::CPUDeviceContext &context,
GRUMetaValue<T> value,
GRUMetaGrad<T> grad,
int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate) {
#if !defined(__NVCC__) && !defined(__HIPCC__)
// calculate grad_update_gate, grad_frame_state,
// grad_reset_output, grad_reset_gate
detail::cpu_gru_backward(context,
phi::funcs::detail::backward::gru<T>(),
value,
grad,
frame_size,
batch_size,
active_node,
active_gate);
auto blas =
phi::funcs::GetBlas<paddle::platform::CPUDeviceContext, T>(context);
if (grad.prev_out_grad && value.prev_out_value) {
// update prev_out_grad
blas.GEMM(false,
false,
batch_size,
frame_size,
frame_size,
1,
grad.gate_grad,
frame_size * 3,
value.gate_weight,
frame_size,
1,
grad.prev_out_grad,
frame_size);
blas.GEMM(false,
false,
batch_size,
frame_size,
frame_size,
1,
grad.gate_grad + frame_size,
frame_size * 3,
value.gate_weight + frame_size * frame_size,
frame_size,
1,
grad.prev_out_grad,
frame_size);
blas.GEMM(false,
false,
batch_size,
frame_size,
frame_size,
1,
grad.reset_output_grad,
frame_size,
value.state_weight,
frame_size,
1,
grad.prev_out_grad,
frame_size);
// update weight_hh_grad
if (grad.gate_weight_grad) {
// reset gate
blas.GEMM(true,
false,
frame_size,
frame_size,
batch_size,
1,
grad.gate_grad,
frame_size * 3,
value.prev_out_value,
frame_size,
1,
grad.gate_weight_grad,
frame_size);
// update gate
blas.GEMM(true,
false,
frame_size,
frame_size,
batch_size,
1,
grad.gate_grad + frame_size,
frame_size * 3,
value.prev_out_value,
frame_size,
1,
grad.gate_weight_grad + frame_size * frame_size,
frame_size);
// cell state
blas.GEMM(true,
false,
frame_size,
frame_size,
batch_size,
1,
grad.reset_output_grad,
frame_size,
value.prev_out_value,
frame_size,
1,
grad.state_weight_grad,
frame_size);
}
}
// update bias_hh_grad
T *gate_grad = grad.gate_grad;
T *bias_hh_grad = grad.bias_hh_grad;
T *state_bias_grad = grad.bias_hh_grad + 2 * frame_size;
T *reset_output_grad = grad.reset_output_grad;
for (int b = 0; b < batch_size; ++b) {
blas.VADD(2 * frame_size, bias_hh_grad, gate_grad, bias_hh_grad);
blas.VADD(
frame_size, state_bias_grad, reset_output_grad, state_bias_grad);
gate_grad += 3 * frame_size;
reset_output_grad += frame_size;
}
#endif
}
};
template struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, float>;
template struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, double>;
template struct GRUUnitGradFunctor<paddle::platform::CPUDeviceContext, float>;
template struct GRUUnitGradFunctor<paddle::platform::CPUDeviceContext, double>;
template struct GRUUnitFunctorV2<paddle::platform::CPUDeviceContext, float>;
template struct GRUUnitFunctorV2<paddle::platform::CPUDeviceContext, double>;
template struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, float>;
template struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext,
double>;
} // namespace funcs
} // namespace phi
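In the forward functor above, the first GEMM writes the gate pre-activations into the first two thirds of each gate_value row, and the leading dimensions encode that packing (the candidate third is filled by the second GEMM). A plain-C++ sketch of the first call's shapes, assuming row-major buffers; this is a naive stand-in for blas.GEMM, not the Paddle API:

// C += A * B, matching the first blas.GEMM call in GRUUnitFunctor::compute:
//   A = prev_out_value [batch_size x frame_size],    lda = frame_size
//   B = gate_weight    [frame_size x 2*frame_size],  ldb = 2 * frame_size
//   C = gate_value     [batch_size x 2*frame_size],  ldc = 3 * frame_size
//       (the first two gate blocks of each 3*frame_size row)
void gemm_accumulate(const float* A, const float* B, float* C,
                     int batch_size, int frame_size) {
  const int m = batch_size, n = 2 * frame_size, k = frame_size;
  const int lda = frame_size, ldb = 2 * frame_size, ldc = 3 * frame_size;
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += A[i * lda + p] * B[p * ldb + j];
      }
      C[i * ldc + j] += acc;  // alpha == 1, beta == 1 in the call above
    }
  }
}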
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h"
#include "paddle/phi/kernels/funcs/detail/gru_kernel.h"
#include "paddle/phi/kernels/funcs/gru_compute.h"
namespace phi {
namespace funcs {
template <typename T>
struct GRUUnitFunctor<paddle::platform::CUDADeviceContext, T> {
static void compute(const paddle::platform::CUDADeviceContext &context,
GRUMetaValue<T> value,
int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
if (batch_size == 1) {
if (context.GetComputeCapability() >= 70) {
if (frame_size < 16) {
constexpr int tiled_size = 8;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<
T,
tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value,
value.prev_out_value,
value.gate_weight,
value.reset_output_value,
frame_size,
active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<
T,
tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight,
value.prev_out_value,
value.output_value,
value.gate_value,
value.reset_output_value,
frame_size,
active_node,
origin_mode);
} else {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<
T,
tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value,
value.prev_out_value,
value.gate_weight,
value.reset_output_value,
frame_size,
active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<
T,
tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight,
value.prev_out_value,
value.output_value,
value.gate_value,
value.reset_output_value,
frame_size,
active_node,
origin_mode);
}
return;
} else {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grid = dim3(frame_blocks, 1);
}
} else {
threads = dim3(32, 32);
grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
auto blas =
phi::funcs::GetBlas<paddle::platform::CUDADeviceContext, T>(context);
if (value.prev_out_value) {
blas.GEMM(false,
false,
batch_size,
frame_size * 2,
frame_size,
1,
value.prev_out_value,
frame_size,
value.gate_weight,
frame_size * 2,
1,
value.gate_value,
frame_size * 3);
}
if (batch_size == 1) {
detail::KeGruForwardResetOutput<
phi::funcs::detail::forward::gru_resetOutput<T>,
/* is_batch= */ false,
T><<<grid, threads, 0, stream>>>(
phi::funcs::detail::forward::gru_resetOutput<T>(),
value.gate_value,
value.reset_output_value,
value.prev_out_value,
frame_size,
batch_size,
active_gate);
} else {
detail::KeGruForwardResetOutput<
phi::funcs::detail::forward::gru_resetOutput<T>,
/* is_batch= */ true,
T><<<grid, threads, 0, stream>>>(
phi::funcs::detail::forward::gru_resetOutput<T>(),
value.gate_value,
value.reset_output_value,
value.prev_out_value,
frame_size,
batch_size,
active_gate);
}
if (value.prev_out_value) {
blas.GEMM(false,
false,
batch_size,
frame_size,
frame_size,
1,
value.reset_output_value,
frame_size,
value.state_weight,
frame_size,
1,
value.gate_value + frame_size * 2,
frame_size * 3);
}
if (batch_size == 1) {
detail::KeGruForwardFinalOutput<
phi::funcs::detail::forward::gru_finalOutput<T>,
/* is_batch= */ false,
T><<<grid, threads, 0, stream>>>(
phi::funcs::detail::forward::gru_finalOutput<T>(),
value.gate_value,
value.prev_out_value,
value.output_value,
frame_size,
batch_size,
active_node,
origin_mode);
} else {
detail::KeGruForwardFinalOutput<
phi::funcs::detail::forward::gru_finalOutput<T>,
/* is_batch= */ true,
T><<<grid, threads, 0, stream>>>(
phi::funcs::detail::forward::gru_finalOutput<T>(),
value.gate_value,
value.prev_out_value,
value.output_value,
frame_size,
batch_size,
active_node,
origin_mode);
}
}
};
template <typename T>
struct GRUUnitGradFunctor<paddle::platform::CUDADeviceContext, T> {
static void compute(const paddle::platform::CUDADeviceContext &context,
GRUMetaValue<T> value,
GRUMetaGrad<T> grad,
int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grid = dim3(frame_blocks, 1);
} else {
threads = dim3(32, 32);
grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
if (batch_size == 1) {
detail::KeGruBackwardStateGrad<
phi::funcs::detail::backward::gru_stateGrad<T>,
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
phi::funcs::detail::backward::gru_stateGrad<T>(),
value.gate_value,
grad.gate_grad,
value.prev_out_value,
grad.prev_out_grad,
grad.output_grad,
frame_size,
batch_size,
active_node,
origin_mode);
} else {
detail::KeGruBackwardStateGrad<
phi::funcs::detail::backward::gru_stateGrad<T>,
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
phi::funcs::detail::backward::gru_stateGrad<T>(),
value.gate_value,
grad.gate_grad,
value.prev_out_value,
grad.prev_out_grad,
grad.output_grad,
frame_size,
batch_size,
active_node,
origin_mode);
}
auto blas =
phi::funcs::GetBlas<paddle::platform::CUDADeviceContext, T>(context);
if (value.prev_out_value && grad.prev_out_grad) {
blas.GEMM(false,
true,
batch_size,
frame_size,
frame_size,
1,
grad.gate_grad + frame_size * 2,
frame_size * 3,
value.state_weight,
frame_size,
0,
grad.reset_output_grad,
frame_size);
if (grad.state_weight_grad) {
blas.GEMM(true,
false,
frame_size,
frame_size,
batch_size,
1,
value.reset_output_value,
frame_size,
grad.gate_grad + frame_size * 2,
frame_size * 3,
1,
grad.state_weight_grad,
frame_size);
}
}
if (batch_size == 1) {
detail::KeGruBackwardResetGrad<
phi::funcs::detail::backward::gru_resetGrad<T>,
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
phi::funcs::detail::backward::gru_resetGrad<T>(),
value.gate_value,
grad.gate_grad,
value.prev_out_value,
grad.prev_out_grad,
grad.reset_output_grad,
frame_size,
batch_size,
active_gate);
} else {
detail::KeGruBackwardResetGrad<
phi::funcs::detail::backward::gru_resetGrad<T>,
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
phi::funcs::detail::backward::gru_resetGrad<T>(),
value.gate_value,
grad.gate_grad,
value.prev_out_value,
grad.prev_out_grad,
grad.reset_output_grad,
frame_size,
batch_size,
active_gate);
}
if (grad.prev_out_grad && value.prev_out_value) {
blas.GEMM(false,
true,
batch_size,
frame_size,
frame_size * 2,
1,
grad.gate_grad,
frame_size * 3,
value.gate_weight,
frame_size * 2,
1,
grad.prev_out_grad,
frame_size);
if (grad.gate_weight_grad) {
blas.GEMM(true,
false,
frame_size,
frame_size * 2,
batch_size,
1,
value.prev_out_value,
frame_size,
grad.gate_grad,
frame_size * 3,
1,
grad.gate_weight_grad,
frame_size * 2);
}
}
}
};
template struct GRUUnitFunctor<paddle::platform::CUDADeviceContext, float>;
template struct GRUUnitFunctor<paddle::platform::CUDADeviceContext, double>;
template struct GRUUnitGradFunctor<paddle::platform::CUDADeviceContext, float>;
template struct GRUUnitGradFunctor<paddle::platform::CUDADeviceContext, double>;
} // namespace funcs
} // namespace phi
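The launch configuration used by these GRU kernels is worth making explicit: for a single sequence they use a 1-D grid over frame elements capped at 1024 threads per block, otherwise a 32x32 2-D grid tiles the (frame_size, batch_size) plane. A standalone sketch of that arithmetic (Dim3 is a stand-in for CUDA's dim3):

#include <cstdio>

struct Dim3 { int x, y; };  // stand-in for CUDA's dim3

// Mirrors the threads/grid selection in GRUUnitGradFunctor::compute above.
void pick_launch_config(int frame_size, int batch_size,
                        Dim3* threads, Dim3* grid) {
  if (batch_size == 1) {
    int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
    int frame_blocks = (frame_size + 1024 - 1) / 1024;  // ceiling division
    *threads = Dim3{frame_per_block, 1};
    *grid = Dim3{frame_blocks, 1};
  } else {
    *threads = Dim3{32, 32};
    *grid = Dim3{(frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32};
  }
}

int main() {
  Dim3 t, g;
  pick_launch_config(/*frame_size=*/1500, /*batch_size=*/8, &t, &g);
  std::printf("threads=(%d,%d) grid=(%d,%d)\n", t.x, t.y, g.x, g.y);
  return 0;
}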
...@@ -11,13 +11,12 @@ limitations under the License. */ ...@@ -11,13 +11,12 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
template <typename T> template <typename T>
struct GRUMetaValue { struct GRUMetaValue {
...@@ -43,38 +42,47 @@ struct GRUMetaGrad { ...@@ -43,38 +42,47 @@ struct GRUMetaGrad {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct GRUUnitFunctor { struct GRUUnitFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value, static void compute(const DeviceContext &context,
int frame_size, int batch_size, GRUMetaValue<T> value,
const detail::ActivationType active_node, int frame_size,
const detail::ActivationType active_gate, int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate,
bool origin_mode); bool origin_mode);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct GRUUnitGradFunctor { struct GRUUnitGradFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value, static void compute(const DeviceContext &context,
GRUMetaGrad<T> grad, int frame_size, int batch_size, GRUMetaValue<T> value,
const detail::ActivationType active_node, GRUMetaGrad<T> grad,
const detail::ActivationType active_gate, int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate,
bool origin_mode); bool origin_mode);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct GRUUnitFunctorV2 { struct GRUUnitFunctorV2 {
static void compute(const DeviceContext &context, GRUMetaValue<T> value, static void compute(const DeviceContext &context,
int frame_size, int batch_size, GRUMetaValue<T> value,
const detail::ActivationType active_node, int frame_size,
const detail::ActivationType active_gate); int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct GRUUnitGradFunctorV2 { struct GRUUnitGradFunctorV2 {
static void compute(const DeviceContext &context, GRUMetaValue<T> value, static void compute(const DeviceContext &context,
GRUMetaGrad<T> grad, int frame_size, int batch_size, GRUMetaValue<T> value,
const detail::ActivationType active_node, GRUMetaGrad<T> grad,
const detail::ActivationType active_gate); int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate);
}; };
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
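For callers migrating from operators/math, the entry point is unchanged apart from the namespace. A hedged call-site sketch against the declarations above, assuming a Paddle build tree; the GRUMetaValue field names are the ones the functor definitions earlier in this diff read, and kSigmoid/kTanh are assumed enumerator names from activation_functions.h:

#include "paddle/phi/kernels/funcs/gru_compute.h"

template <typename T>
void run_gru_step(const paddle::platform::CPUDeviceContext& ctx,
                  T* gate_buf,       // [batch_size x 3*frame_size]
                  T* reset_out_buf,  // [batch_size x frame_size]
                  T* out_buf,        // [batch_size x frame_size]
                  T* h_prev,         // may be nullptr at the first step
                  T* w_gate,         // [frame_size x 2*frame_size]
                  T* w_state,        // [frame_size x frame_size]
                  int frame_size, int batch_size) {
  phi::funcs::GRUMetaValue<T> value;
  value.gate_value = gate_buf;
  value.reset_output_value = reset_out_buf;
  value.output_value = out_buf;
  value.prev_out_value = h_prev;
  value.gate_weight = w_gate;
  value.state_weight = w_state;
  phi::funcs::GRUUnitFunctor<paddle::platform::CPUDeviceContext, T>::compute(
      ctx, value, frame_size, batch_size,
      phi::funcs::detail::ActivationType::kTanh,     // active_node (assumed name)
      phi::funcs::detail::ActivationType::kSigmoid,  // active_gate (assumed name)
      /*origin_mode=*/false);
}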
...@@ -12,33 +12,34 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,33 +12,34 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/phi/kernels/funcs/lstm_compute.h"
#include "paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h"
#include "paddle/phi/kernels/funcs/detail/lstm_kernel.h"
#include "paddle/fluid/operators/math/detail/lstm_cpu_kernel.h" namespace phi {
#include "paddle/fluid/operators/math/detail/lstm_kernel.h" namespace funcs {
namespace paddle {
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace math {
template <class T> template <class T>
struct LstmUnitFunctor<platform::CPUDeviceContext, T> { struct LstmUnitFunctor<paddle::platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context, static void compute(const paddle::platform::CPUDeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value,
T cell_clip, const detail::ActivationType& gate_act, int frame_size,
const detail::ActivationType& cell_act, int batch_size,
const detail::ActivationType& cand_act, T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) { bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(context, detail::forward::lstm<T>(), value, detail::cpu_lstm_forward(context,
frame_size, cell_clip, cand_act, gate_act, phi::funcs::detail::forward::lstm<T>(),
cell_act, old_api_version); value,
frame_size,
cell_clip,
cand_act,
gate_act,
cell_act,
old_api_version);
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
value.state_active_value += frame_size; value.state_active_value += frame_size;
...@@ -51,18 +52,28 @@ struct LstmUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -51,18 +52,28 @@ struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
}; };
template <class T> template <class T>
struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> { struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context, static void compute(const paddle::platform::CPUDeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value,
int frame_size, int batch_size, T cell_clip, LstmMetaGrad<T> grad,
const detail::ActivationType& gate_act, int frame_size,
const detail::ActivationType& cell_act, int batch_size,
const detail::ActivationType& cand_act, T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) { bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(context, detail::backward::lstm<T>(), value, detail::cpu_lstm_backward(context,
grad, frame_size, cell_clip, cand_act, gate_act, phi::funcs::detail::backward::lstm<T>(),
cell_act, old_api_version); value,
grad,
frame_size,
cell_clip,
cand_act,
gate_act,
cell_act,
old_api_version);
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
...@@ -83,11 +94,10 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -83,11 +94,10 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
} }
}; };
template class LstmUnitFunctor<platform::CPUDeviceContext, float>; template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, float>;
template class LstmUnitFunctor<platform::CPUDeviceContext, double>; template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, double>;
template class LstmUnitGradFunctor<platform::CPUDeviceContext, float>; template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, float>;
template class LstmUnitGradFunctor<platform::CPUDeviceContext, double>; template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, double>;
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/funcs/detail/lstm_gpu_kernel.h"
#include "paddle/phi/kernels/funcs/detail/lstm_kernel.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h"
namespace phi {
namespace funcs {
template <class T>
struct LstmUnitFunctor<paddle::platform::CUDADeviceContext, T> {
static void compute(const paddle::platform::CUDADeviceContext& context,
LstmMetaValue<T> value,
int frame_size,
int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) {
detail::gpu_lstm_forward<T>(context,
phi::funcs::detail::forward::lstm<T>(),
value,
frame_size,
batch_size,
cell_clip,
cand_act,
gate_act,
cell_act);
}
};
template <class T>
struct LstmUnitGradFunctor<paddle::platform::CUDADeviceContext, T> {
static void compute(const paddle::platform::CUDADeviceContext& context,
LstmMetaValue<T> value,
LstmMetaGrad<T> grad,
int frame_size,
int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) {
detail::gpu_lstm_backward(context,
phi::funcs::detail::backward::lstm<T>(),
value,
grad,
frame_size,
batch_size,
cell_clip,
cand_act,
gate_act,
cell_act);
}
};
template class LstmUnitFunctor<paddle::platform::CUDADeviceContext, float>;
template class LstmUnitFunctor<paddle::platform::CUDADeviceContext, double>;
template class LstmUnitGradFunctor<paddle::platform::CUDADeviceContext, float>;
template class LstmUnitGradFunctor<paddle::platform::CUDADeviceContext, double>;
} // namespace funcs
} // namespace phi
...@@ -14,13 +14,12 @@ limitations under the License. */ ...@@ -14,13 +14,12 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
template <class T> template <class T>
struct LstmMetaValue { struct LstmMetaValue {
...@@ -49,25 +48,31 @@ struct LstmMetaGrad { ...@@ -49,25 +48,31 @@ struct LstmMetaGrad {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class LstmUnitFunctor { class LstmUnitFunctor {
public: public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value, static void compute(const DeviceContext &context,
int frame_size, int batch_size, T cell_clip, LstmMetaValue<T> value,
const detail::ActivationType &gate_act, int frame_size,
const detail::ActivationType &cell_act, int batch_size,
const detail::ActivationType &cand_act, T cell_clip,
const phi::funcs::detail::ActivationType &gate_act,
const phi::funcs::detail::ActivationType &cell_act,
const phi::funcs::detail::ActivationType &cand_act,
bool old_api_version = true); bool old_api_version = true);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class LstmUnitGradFunctor { class LstmUnitGradFunctor {
public: public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value, static void compute(const DeviceContext &context,
LstmMetaGrad<T> grad, int frame_size, int batch_size, LstmMetaValue<T> value,
T cell_clip, const detail::ActivationType &gate_act, LstmMetaGrad<T> grad,
const detail::ActivationType &cell_act, int frame_size,
const detail::ActivationType &cand_act, int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType &gate_act,
const phi::funcs::detail::ActivationType &cell_act,
const phi::funcs::detail::ActivationType &cand_act,
bool old_api_version = true); bool old_api_version = true);
}; };
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
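One detail worth flagging for callers: compute takes the activations as (gate_act, cell_act, cand_act), but the CPU and GPU bodies above forward them as (cand_act, gate_act, cell_act) into the kernels' (active_node, active_gate, active_state) parameters, so the candidate activation becomes the node activation. A hedged call-site sketch, assuming a Paddle build tree and the same assumed enumerator names as the GRU sketch earlier:

#include "paddle/phi/kernels/funcs/lstm_compute.h"

template <typename T>
void run_lstm_step(const paddle::platform::CPUDeviceContext& ctx,
                   phi::funcs::LstmMetaValue<T> value,  // buffers set up by caller
                   int frame_size, int batch_size, T cell_clip) {
  phi::funcs::LstmUnitFunctor<paddle::platform::CPUDeviceContext, T>::compute(
      ctx, value, frame_size, batch_size, cell_clip,
      phi::funcs::detail::ActivationType::kSigmoid,  // gate_act (assumed name)
      phi::funcs::detail::ActivationType::kTanh,     // cell_act (assumed name)
      phi::funcs::detail::ActivationType::kTanh);    // cand_act (assumed name)
}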
...@@ -12,47 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,47 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/phi/kernels/funcs/sequence2batch.h"
namespace paddle { namespace phi {
namespace platform { namespace funcs {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace math {
template <typename T> template <typename T>
class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> { class CopyMatrixRowsFunctor<paddle::platform::CPUDeviceContext, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const paddle::platform::CPUDeviceContext& context,
const framework::Tensor& src, const paddle::framework::Tensor& src,
framework::Vector<size_t> index_lod, framework::Tensor* dst, paddle::framework::Vector<size_t> index_lod,
paddle::framework::Tensor* dst,
bool is_src_index) { bool is_src_index) {
size_t* index = index_lod.data(); size_t* index = index_lod.data();
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst->dims(); auto dst_dims = dst->dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2UL, PADDLE_ENFORCE_EQ(src_dims.size(),
platform::errors::InvalidArgument( 2UL,
phi::errors::InvalidArgument(
"The source tensor must be a matrix with rank 2, but " "The source tensor must be a matrix with rank 2, but "
"got the source tensor rank is %lu. " "got the source tensor rank is %lu. "
"Please check the rank of the source tensor", "Please check the rank of the source tensor",
src_dims.size())); src_dims.size()));
PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL, PADDLE_ENFORCE_EQ(dst_dims.size(),
platform::errors::InvalidArgument( 2UL,
phi::errors::InvalidArgument(
"The destination tensor must be a matrix with rank, " "The destination tensor must be a matrix with rank, "
"but got the destination tensor rank is %lu. " "but got the destination tensor rank is %lu. "
"Please check the rank of the destination tensor", "Please check the rank of the destination tensor",
dst_dims.size())); dst_dims.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
src_dims[1], dst_dims[1], src_dims[1],
platform::errors::InvalidArgument( dst_dims[1],
phi::errors::InvalidArgument(
"The width of the source tensor and the destination tensor must be " "The width of the source tensor and the destination tensor must be "
"same. But got %lu != %lu.Please check the rank of the source " "same. But got %lu != %lu.Please check the rank of the source "
"tensor", "tensor",
src_dims.size(), dst_dims.size())); src_dims.size(),
dst_dims.size()));
auto height = dst_dims[0]; auto height = dst_dims[0];
auto width = dst_dims[1]; auto width = dst_dims[1];
auto* src_data = src.data<T>(); auto* src_data = src.data<T>();
...@@ -70,14 +68,18 @@ class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> { ...@@ -70,14 +68,18 @@ class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> {
} }
}; };
template class CopyMatrixRowsFunctor<platform::CPUDeviceContext, float>; template class CopyMatrixRowsFunctor<paddle::platform::CPUDeviceContext, float>;
template class CopyMatrixRowsFunctor<platform::CPUDeviceContext, double>; template class CopyMatrixRowsFunctor<paddle::platform::CPUDeviceContext,
double>;
template class LoDTensor2BatchFunctor<platform::CPUDeviceContext, float>; template class LoDTensor2BatchFunctor<paddle::platform::CPUDeviceContext,
template class LoDTensor2BatchFunctor<platform::CPUDeviceContext, double>; float>;
template class Batch2LoDTensorFunctor<platform::CPUDeviceContext, float>; template class LoDTensor2BatchFunctor<paddle::platform::CPUDeviceContext,
template class Batch2LoDTensorFunctor<platform::CPUDeviceContext, double>; double>;
template class Batch2LoDTensorFunctor<paddle::platform::CPUDeviceContext,
float>;
template class Batch2LoDTensorFunctor<paddle::platform::CPUDeviceContext,
double>;
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
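The functor's semantics (documented in the header below) reduce to a row gather or scatter driven by index_lod. A standalone CPU sketch, assuming contiguous row-major [height x width] buffers:

#include <cstdint>
#include <cstring>

// dst row i <- src row index[i] when is_src_index (gather, LoDTensor->batch);
// dst row index[i] <- src row i otherwise (scatter, batch->LoDTensor).
void copy_matrix_rows(const float* src, const size_t* index, float* dst,
                      int64_t height, int64_t width, bool is_src_index) {
  for (int64_t i = 0; i < height; ++i) {
    const float* s = src + (is_src_index ? index[i] : i) * width;
    float* d = dst + (is_src_index ? i : index[i]) * width;
    std::memcpy(d, s, sizeof(float) * width);
  }
}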
...@@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/phi/kernels/funcs/sequence2batch.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX> template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index, __global__ void CopyMatrixRowsKernel(const T* src,
int64_t height, int64_t width, T* dst,
const size_t* index,
int64_t height,
int64_t width,
bool is_src_index) { bool is_src_index) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = threadIdx.y; int idy = threadIdx.y;
...@@ -37,33 +39,38 @@ __global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index, ...@@ -37,33 +39,38 @@ __global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
} }
template <typename T> template <typename T>
class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> { class CopyMatrixRowsFunctor<paddle::platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const paddle::platform::CUDADeviceContext& context,
const framework::Tensor& src, const paddle::framework::Tensor& src,
framework::Vector<size_t> index_lod, framework::Tensor* dst, paddle::framework::Vector<size_t> index_lod,
paddle::framework::Tensor* dst,
bool is_src_index) { bool is_src_index) {
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst->dims(); auto dst_dims = dst->dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2, PADDLE_ENFORCE_EQ(src_dims.size(),
platform::errors::InvalidArgument( 2,
phi::errors::InvalidArgument(
"The source tensor must be a matrix with rank 2, but " "The source tensor must be a matrix with rank 2, but "
"got the source tensor rank is %lu. " "got the source tensor rank is %lu. "
"Please check the rank of the source tensor", "Please check the rank of the source tensor",
src_dims.size())); src_dims.size()));
PADDLE_ENFORCE_EQ(dst_dims.size(), 2, PADDLE_ENFORCE_EQ(dst_dims.size(),
platform::errors::InvalidArgument( 2,
phi::errors::InvalidArgument(
"The destination tensor must be a matrix with rank, " "The destination tensor must be a matrix with rank, "
"but got the destination tensor rank is %lu. " "but got the destination tensor rank is %lu. "
"Please check the rank of the destination tensor", "Please check the rank of the destination tensor",
dst_dims.size())); dst_dims.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
src_dims[1], dst_dims[1], src_dims[1],
platform::errors::InvalidArgument( dst_dims[1],
phi::errors::InvalidArgument(
"The width of the source tensor and the destination tensor must be " "The width of the source tensor and the destination tensor must be "
"same. But got %lu != %lu.Please check the rank of the source " "same. But got %lu != %lu.Please check the rank of the source "
"tensor", "tensor",
src_dims.size(), dst_dims.size())); src_dims.size(),
dst_dims.size()));
auto height = dst_dims[0]; auto height = dst_dims[0];
auto width = dst_dims[1]; auto width = dst_dims[1];
auto* src_data = src.data<T>(); auto* src_data = src.data<T>();
...@@ -74,19 +81,28 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> { ...@@ -74,19 +81,28 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
auto stream = context.stream(); auto stream = context.stream();
paddle::framework::MixVector<size_t> mix_index_lod(&index_lod); paddle::framework::MixVector<size_t> mix_index_lod(&index_lod);
CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>( CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
src_data, dst_data, mix_index_lod.CUDAData(context.GetPlace()), height, src_data,
width, is_src_index); dst_data,
mix_index_lod.CUDAData(context.GetPlace()),
height,
width,
is_src_index);
} }
}; };
template class CopyMatrixRowsFunctor<platform::CUDADeviceContext, float>; template class CopyMatrixRowsFunctor<paddle::platform::CUDADeviceContext,
template class CopyMatrixRowsFunctor<platform::CUDADeviceContext, double>; float>;
template class CopyMatrixRowsFunctor<paddle::platform::CUDADeviceContext,
double>;
template class LoDTensor2BatchFunctor<platform::CUDADeviceContext, float>; template class LoDTensor2BatchFunctor<paddle::platform::CUDADeviceContext,
template class LoDTensor2BatchFunctor<platform::CUDADeviceContext, double>; float>;
template class Batch2LoDTensorFunctor<platform::CUDADeviceContext, float>; template class LoDTensor2BatchFunctor<paddle::platform::CUDADeviceContext,
template class Batch2LoDTensorFunctor<platform::CUDADeviceContext, double>; double>;
template class Batch2LoDTensorFunctor<paddle::platform::CUDADeviceContext,
float>;
template class Batch2LoDTensorFunctor<paddle::platform::CUDADeviceContext,
double>;
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
...@@ -20,13 +20,13 @@ limitations under the License. */ ...@@ -20,13 +20,13 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
template <typename T, int MajorType = Eigen::RowMajor, template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = paddle::framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class CopyMatrixRowsFunctor { class CopyMatrixRowsFunctor {
...@@ -36,8 +36,10 @@ class CopyMatrixRowsFunctor { ...@@ -36,8 +36,10 @@ class CopyMatrixRowsFunctor {
// If is_src_index is false, // If is_src_index is false,
// copy the input src to the indexed rows of output dst. // copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index. // The indexed rows are based on the input index.
void operator()(const DeviceContext& context, const framework::Tensor& src, void operator()(const DeviceContext& context,
framework::Vector<size_t> index_lod, framework::Tensor* dst, const paddle::framework::Tensor& src,
paddle::framework::Vector<size_t> index_lod,
paddle::framework::Tensor* dst,
bool is_src_index); bool is_src_index);
}; };
...@@ -59,32 +61,37 @@ class LoDTensor2BatchFunctor { ...@@ -59,32 +61,37 @@ class LoDTensor2BatchFunctor {
public: public:
void operator()(const DeviceContext& context, void operator()(const DeviceContext& context,
const framework::LoDTensor& lod_tensor, const paddle::framework::LoDTensor& lod_tensor,
framework::LoDTensor* batch, bool is_cal_batch_lod, paddle::framework::LoDTensor* batch,
bool is_cal_batch_lod,
bool is_reverse = false) const { bool is_reverse = false) const {
if (!is_cal_batch_lod) { if (!is_cal_batch_lod) {
auto lods = batch->lod(); auto lods = batch->lod();
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
lods.size(), 2UL, lods.size(),
platform::errors::InvalidArgument( 2UL,
phi::errors::InvalidArgument(
"The LoD of LoDTensor should inlcude at least 2-level " "The LoD of LoDTensor should inlcude at least 2-level "
"sequence information, but got the LoD level is %lu. Please " "sequence information, but got the LoD level is %lu. Please "
"check the input value.", "check the input value.",
lods.size())); lods.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
lods[1].size(), static_cast<size_t>(lod_tensor.dims()[0]), lods[1].size(),
platform::errors::InvalidArgument( static_cast<size_t>(lod_tensor.dims()[0]),
phi::errors::InvalidArgument(
"The LoD information should be consistent with the dims, but got " "The LoD information should be consistent with the dims, but got "
"%lu != %lu. Please check the input value.", "%lu != %lu. Please check the input value.",
lods[1].size(), static_cast<size_t>(lod_tensor.dims()[0]))); lods[1].size(),
static_cast<size_t>(lod_tensor.dims()[0])));
CopyMatrixRowsFunctor<DeviceContext, T> to_batch; CopyMatrixRowsFunctor<DeviceContext, T> to_batch;
to_batch(context, lod_tensor, lods[1], batch, true); to_batch(context, lod_tensor, lods[1], batch, true);
return; return;
} }
auto lods = lod_tensor.lod(); auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, PADDLE_ENFORCE_EQ(lods.size(),
platform::errors::InvalidArgument( 1UL,
phi::errors::InvalidArgument(
"Only support one level sequence now, but got the " "Only support one level sequence now, but got the "
"LoD level is %lu. Please check the input value.", "LoD level is %lu. Please check the input value.",
lods.size())); lods.size()));
...@@ -97,8 +104,9 @@ class LoDTensor2BatchFunctor { ...@@ -97,8 +104,9 @@ class LoDTensor2BatchFunctor {
seq_info.emplace_back(lod[seq_id], length, seq_id); seq_info.emplace_back(lod[seq_id], length, seq_id);
} }
std::sort(seq_info.begin(), seq_info.end(), std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
[](SeqInfo a, SeqInfo b) { return a.length > b.length; }); return a.length > b.length;
});
// Calculate the start position of each batch. // Calculate the start position of each batch.
// example: sequences = {s0, s1, s2} // example: sequences = {s0, s1, s2}
...@@ -169,27 +177,29 @@ template <typename DeviceContext, typename T> ...@@ -169,27 +177,29 @@ template <typename DeviceContext, typename T>
class Batch2LoDTensorFunctor { class Batch2LoDTensorFunctor {
public: public:
void operator()(const DeviceContext& context, void operator()(const DeviceContext& context,
const framework::LoDTensor& batch, const paddle::framework::LoDTensor& batch,
framework::LoDTensor* lod_tensor) const { paddle::framework::LoDTensor* lod_tensor) const {
auto in_lod = batch.lod(); auto in_lod = batch.lod();
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
in_lod.size(), 2UL, in_lod.size(),
platform::errors::InvalidArgument( 2UL,
phi::errors::InvalidArgument(
"The LoD of LoDTensor should inlcude at least 2-level " "The LoD of LoDTensor should inlcude at least 2-level "
"sequence information, but got the LoD level is %lu. Please check " "sequence information, but got the LoD level is %lu. Please check "
"the input value.", "the input value.",
in_lod.size())); in_lod.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0]), in_lod[1].size(),
platform::errors::InvalidArgument( static_cast<size_t>(lod_tensor->dims()[0]),
phi::errors::InvalidArgument(
"The LoD information should be consistent with the dims, but got " "The LoD information should be consistent with the dims, but got "
"%lu != %lu. Please check the input value.", "%lu != %lu. Please check the input value.",
in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0]))); in_lod[1].size(),
static_cast<size_t>(lod_tensor->dims()[0])));
CopyMatrixRowsFunctor<DeviceContext, T> to_seq; CopyMatrixRowsFunctor<DeviceContext, T> to_seq;
to_seq(context, batch, in_lod[1], lod_tensor, false); to_seq(context, batch, in_lod[1], lod_tensor, false);
} }
}; };
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
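To make the batching concrete, here is a standalone sketch of the index computation for one LoD level. It mirrors the sort-by-length scheme described in the comments above; the real functor also writes the batch LoD metadata and honors is_reverse:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> lod = {0, 3, 4, 6};  // three sequences: lengths 3, 1, 2
  struct Seq { size_t start, length, idx; };
  std::vector<Seq> seqs;
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    seqs.push_back({lod[i], lod[i + 1] - lod[i], i});
  }
  // Longest sequence first, as in the std::sort call above.
  std::sort(seqs.begin(), seqs.end(),
            [](const Seq& a, const Seq& b) { return a.length > b.length; });
  // Batch t holds the t-th step of every sequence long enough to have one.
  std::vector<size_t> index;  // batch row -> source row
  size_t max_len = seqs.front().length;
  for (size_t t = 0; t < max_len; ++t) {
    for (const Seq& s : seqs) {
      if (t < s.length) index.push_back(s.start + t);
    }
  }
  // Prints 0, 4, 3, 1, 5, 2 for this LoD.
  for (size_t k = 0; k < index.size(); ++k) {
    std::printf("batch row %zu <- source row %zu\n", k, index[k]);
  }
  return 0;
}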