diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h
index 14abd4bf0a6e73a9c0f000f53a5e1e380f01d1c0..c57ee414dc5b3417549c8ac3a7fd57a9c8f452df 100644
--- a/paddle/operators/lstm_op.h
+++ b/paddle/operators/lstm_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/operators/math/lstm_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
@@ -102,9 +103,12 @@ class LSTMKernel : public framework::OpKernel<T> {
 
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
-    auto gate_act = ctx.Attr<std::string>("gate_activation");
-    auto cell_act = ctx.Attr<std::string>("cell_activation");
-    auto cand_act = ctx.Attr<std::string>("candidate_activation");
+    auto gate_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("gate_activation"));
+    auto cell_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("cell_activation"));
+    auto cand_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("candidate_activation"));
 
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
@@ -264,9 +268,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
 
     batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
     batch_gate_g.set_lod(batch_gate->lod());
-    auto gate_act = ctx.Attr<std::string>("gate_activation");
-    auto cell_act = ctx.Attr<std::string>("cell_activation");
-    auto cand_act = ctx.Attr<std::string>("candidate_activation");
+    auto gate_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("gate_activation"));
+    auto cell_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("cell_activation"));
+    auto cand_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("candidate_activation"));
 
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
diff --git a/paddle/operators/math/detail/activation_functions.h b/paddle/operators/math/detail/activation_functions.h
index a20c35d1d9dc4a3a6fae92023fd1aae787a716ec..585a0123437a39c2b610306b18fe0a970c0ed072 100644
--- a/paddle/operators/math/detail/activation_functions.h
+++ b/paddle/operators/math/detail/activation_functions.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include <math.h>
+#include "paddle/platform/enforce.h"
 #include "paddle/platform/hostdevice.h"
 
 #ifdef __AVX__
@@ -29,6 +30,26 @@ namespace detail {
 #define SIGMOID_THRESHOLD_MAX 13.0
 #define EXP_MAX_INPUT 40.0
 
+enum ActivationType {
+  kSigmoid,
+  kReLU,
+  kTanh,
+  kIdentity,
+};
+
+inline ActivationType GetActivationType(const std::string &type) {
+  if (type == "sigmoid") {
+    return ActivationType::kSigmoid;
+  } else if (type == "relu") {
+    return ActivationType::kReLU;
+  } else if (type == "tanh") {
+    return ActivationType::kTanh;
+  } else if (type == "identity" || type == "") {
+    return ActivationType::kIdentity;
+  }
+  PADDLE_THROW("Not support type %s.", type);
+}
+
 namespace forward {
 
 template <typename T>
diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h
index a734ad31eea4816e952641bad73776d93d8c8d34..42888fcdb0a464892e3007ee73c195fcd2a431bb 100644
--- a/paddle/operators/math/detail/lstm_cpu_kernel.h
+++ b/paddle/operators/math/detail/lstm_cpu_kernel.h
@@ -26,10 +26,9 @@ namespace detail {
 
 template <class T, class Op>
 void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                     int frame_size,
-                                     activation_mode_t active_node,
-                                     activation_mode_t active_gate,
-                                     activation_mode_t active_state) {
+                                     int frame_size, ActivationType active_node,
+                                     ActivationType active_gate,
+                                     ActivationType active_state) {
   T r_value_in;
   T r_value_ig;
   T r_value_fg;
@@ -77,9 +76,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
                                       LstmMetaGrad<T> grad, int frame_size,
-                                      activation_mode_t active_node,
-                                      activation_mode_t active_gate,
-                                      activation_mode_t active_state) {
+                                      ActivationType active_node,
+                                      ActivationType active_gate,
+                                      ActivationType active_state) {
   T r_value_in;
   T r_value_ig;
   T r_value_fg;
@@ -149,10 +148,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 
 template <class T, class Op>
 void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                   int frame_size,
-                                   activation_mode_t active_node,
-                                   activation_mode_t active_gate,
-                                   activation_mode_t active_state) {
+                                   int frame_size, ActivationType active_node,
+                                   ActivationType active_gate,
+                                   ActivationType active_state) {
 #ifdef __AVX__
   __m256 r_value_in;
   __m256 r_value_ig;
@@ -204,9 +202,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
                                     LstmMetaGrad<T> grad, int frame_size,
-                                    activation_mode_t active_node,
-                                    activation_mode_t active_gate,
-                                    activation_mode_t active_state) {
+                                    ActivationType active_node,
+                                    ActivationType active_gate,
+                                    ActivationType active_state) {
 #ifdef __AVX__
   __m256 r_value_in;
   __m256 r_value_ig;
@@ -281,9 +279,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 
 template <class T, class Op>
 void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate,
-                      activation_mode_t active_state) {
+                      ActivationType active_node, ActivationType active_gate,
+                      ActivationType active_state) {
   if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
     avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
                                      active_gate, active_state);
@@ -295,9 +292,9 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
 
 template <class T, class Op>
 void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                       int frame_size, activation_mode_t active_node,
-                       activation_mode_t active_gate,
-                       activation_mode_t active_state) {
+                       int frame_size, ActivationType active_node,
+                       ActivationType active_gate,
+                       ActivationType active_state) {
   if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
     avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
                                       active_gate, active_state);
diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h
index 91bfedea53a2600156c9025f6ff3615d695a712b..e31e657e8b6964c2b99f6e456545c83d8da8e7f9 100644
--- a/paddle/operators/math/detail/lstm_gpu_kernel.h
+++ b/paddle/operators/math/detail/lstm_gpu_kernel.h
@@ -31,9 +31,9 @@ namespace detail {
  */
 template <class T, class Op, bool is_batch>
 __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
-                              int batch_size, activation_mode_t active_node,
-                              activation_mode_t active_gate,
-                              activation_mode_t active_state) {
+                              int batch_size, ActivationType active_node,
+                              ActivationType active_gate,
+                              ActivationType active_state) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
 
@@ -91,9 +91,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
 template <class T, class Op, bool is_batch>
 __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
                                LstmMetaGrad<T> grad, int frame_size,
-                               int batch_size, activation_mode_t active_node,
-                               activation_mode_t active_gate,
-                               activation_mode_t active_state) {
+                               int batch_size, ActivationType active_node,
+                               ActivationType active_gate,
+                               ActivationType active_state) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
 
@@ -185,9 +185,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate,
-                      activation_mode_t active_state) {
+                      ActivationType active_node, ActivationType active_gate,
+                      ActivationType active_state) {
   dim3 threads;
   dim3 grid;
   if (batch_size == 1) {
@@ -220,9 +219,8 @@ template <class T, class Op>
 void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
                        LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                        int frame_size, int batch_size,
-                       activation_mode_t active_node,
-                       activation_mode_t active_gate,
-                       activation_mode_t active_state) {
+                       ActivationType active_node, ActivationType active_gate,
+                       ActivationType active_state) {
   dim3 threads;
   dim3 grid;
   if (batch_size == 1) {
diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h
index 78f9a249a3d5d413452952edf990975c02f1a369..fed8f9c4ca48905ad4c524ba400e8c7bb2f7fbd1 100644
--- a/paddle/operators/math/detail/lstm_kernel.h
+++ b/paddle/operators/math/detail/lstm_kernel.h
@@ -30,9 +30,9 @@ class lstm {
   HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
                              T &prev_state, T &state, T &state_atv, T &output,
                              T &checkI, T &checkF, T &checkO,
-                             activation_mode_t active_node,
-                             activation_mode_t active_gate,
-                             activation_mode_t active_state) {
+                             ActivationType active_node,
+                             ActivationType active_gate,
+                             ActivationType active_state) {
     value_in = activation(value_in, active_node);
     value_ig = activation(value_ig + prev_state * checkI, active_gate);
     value_fg = activation(value_fg + prev_state * checkF, active_gate);
@@ -53,9 +53,9 @@ class lstm {
                              __m256 &prev_state, __m256 &state,
                              __m256 &state_atv, __m256 &output, __m256 &checkI,
                              __m256 &checkF, __m256 &checkO,
-                             activation_mode_t active_node,
-                             activation_mode_t active_gate,
-                             activation_mode_t active_state) {
+                             ActivationType active_node,
+                             ActivationType active_gate,
+                             ActivationType active_state) {
     value_in = activation(value_in, active_node);
     value_ig =
         activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
@@ -87,9 +87,9 @@ class lstm {
                              T &state_grad, T &state_atv, T &output_grad,
                              T &checkI, T &checkF, T &checkO, T &checkIGrad,
                              T &checkFGrad, T &checkOGrad,
-                             activation_mode_t active_node,
-                             activation_mode_t active_gate,
-                             activation_mode_t active_state) {
+                             ActivationType active_node,
+                             ActivationType active_gate,
+                             ActivationType active_state) {
     grad_og = activation(output_grad * state_atv, value_og, active_gate);
     state_grad += activation(output_grad * value_og, state_atv, active_state) +
                   grad_og * checkO;
@@ -114,8 +114,8 @@ class lstm {
       __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
       __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
       __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
-      __m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
-      activation_mode_t active_gate, activation_mode_t active_state) {
+      __m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
+      ActivationType active_gate, ActivationType active_state) {
     grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
                          active_gate);
     state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc
index 2c2e8bb82e6f51e21a00de53bbfce5f0b4868e27..d453102ecefc9d79e1f4474ba94be0eb69a87c85 100644
--- a/paddle/operators/math/lstm_compute.cc
+++ b/paddle/operators/math/lstm_compute.cc
@@ -24,12 +24,12 @@ template <class T>
 struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext& context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      const std::string& gate_act, const std::string& cell_act,
-                      const std::string& cand_act) {
+                      const detail::ActivationType& gate_act,
+                      const detail::ActivationType& cell_act,
+                      const detail::ActivationType& cand_act) {
     for (int b = 0; b < batch_size; b++) {
       detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
-                               ActiveType(cand_act), ActiveType(gate_act),
-                               ActiveType(cell_act));
+                               cand_act, gate_act, cell_act);
       value.gate_value += frame_size * 4;
       value.state_value += frame_size;
       value.state_active_value += frame_size;
@@ -46,12 +46,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext& context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      const std::string& gate_act, const std::string& cell_act,
-                      const std::string& cand_act) {
+                      const detail::ActivationType& gate_act,
+                      const detail::ActivationType& cell_act,
+                      const detail::ActivationType& cand_act) {
     for (int b = 0; b < batch_size; b++) {
       detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
-                                frame_size, ActiveType(cand_act),
-                                ActiveType(gate_act), ActiveType(cell_act));
+                                frame_size, cand_act, gate_act, cell_act);
 
       value.gate_value += frame_size * 4;
       value.state_value += frame_size;
diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu
index 92b1f4228b49709d2903fab518e7649133932fad..82065d699f760db6cc86bf3d6c56e51c583c6ace 100644
--- a/paddle/operators/math/lstm_compute.cu
+++ b/paddle/operators/math/lstm_compute.cu
@@ -24,11 +24,12 @@ template <class T>
 struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext& context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      const std::string& gate_act, const std::string& cell_act,
-                      const std::string& cand_act) {
+                      const detail::ActivationType& gate_act,
+                      const detail::ActivationType& cell_act,
+                      const detail::ActivationType& cand_act) {
     detail::gpu_lstm_forward(context, detail::forward::lstm<T>(), value,
-                             frame_size, batch_size, ActiveType(cand_act),
-                             ActiveType(gate_act), ActiveType(cell_act));
+                             frame_size, batch_size, cand_act, gate_act,
+                             cell_act);
   }
 };
 
@@ -37,11 +38,12 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext& context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      const std::string& gate_act, const std::string& cell_act,
-                      const std::string& cand_act) {
+                      const detail::ActivationType& gate_act,
+                      const detail::ActivationType& cell_act,
+                      const detail::ActivationType& cand_act) {
     detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
-                              frame_size, batch_size, ActiveType(cand_act),
-                              ActiveType(gate_act), ActiveType(cell_act));
+                              frame_size, batch_size, cand_act, gate_act,
+                              cell_act);
   }
 };
 
diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h
index 5f74e273585aea5184281bf294df694235150e30..954762f92286fe13bd2c08ec03c3ac96bb663cca 100644
--- a/paddle/operators/math/lstm_compute.h
+++ b/paddle/operators/math/lstm_compute.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
 
@@ -72,8 +73,9 @@ class LstmUnitFunctor {
  public:
   static void compute(const DeviceContext &context, LstmMetaValue<T> value,
                       int frame_size, int batch_size,
-                      const std::string &gate_act, const std::string &cell_act,
-                      const std::string &cand_act);
+                      const detail::ActivationType &gate_act,
+                      const detail::ActivationType &cell_act,
+                      const detail::ActivationType &cand_act);
 };
 
 template <typename DeviceContext, typename T>
@@ -81,8 +83,9 @@ class LstmUnitGradFunctor {
  public:
   static void compute(const DeviceContext &context, LstmMetaValue<T> value,
                       LstmMetaGrad<T> grad, int frame_size, int batch_size,
-                      const std::string &gate_act, const std::string &cell_act,
-                      const std::string &cand_act);
+                      const detail::ActivationType &gate_act,
+                      const detail::ActivationType &cell_act,
+                      const detail::ActivationType &cand_act);
 };
 
 }  // namespace math
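
Usage sketch (not part of the patch): the snippet below is a minimal, self-contained illustration of the pattern this diff introduces — the activation string attribute is converted to the ActivationType enum once via GetActivationType, and the compute path then dispatches on the enum instead of re-parsing strings. It mirrors the helper added to activation_functions.h but stands alone: std::invalid_argument is used here as a stand-in for PADDLE_THROW, and apply_activation is a hypothetical helper, not an API from this change.

// Standalone sketch (assumption: not part of the patch above).
// GetActivationType mirrors the helper added in activation_functions.h;
// std::invalid_argument stands in for PADDLE_THROW, and apply_activation
// is a hypothetical illustration of enum-based dispatch.
#include <cmath>
#include <iostream>
#include <stdexcept>
#include <string>

enum ActivationType {
  kSigmoid,
  kReLU,
  kTanh,
  kIdentity,
};

inline ActivationType GetActivationType(const std::string &type) {
  if (type == "sigmoid") return kSigmoid;
  if (type == "relu") return kReLU;
  if (type == "tanh") return kTanh;
  if (type == "identity" || type == "") return kIdentity;
  throw std::invalid_argument("Not support type " + type);
}

// Compute code branches on the enum, so no string comparison happens per frame.
inline float apply_activation(float x, ActivationType act) {
  switch (act) {
    case kSigmoid:
      return 1.0f / (1.0f + std::exp(-x));
    case kReLU:
      return x > 0.0f ? x : 0.0f;
    case kTanh:
      return std::tanh(x);
    case kIdentity:
      return x;
  }
  return x;  // unreachable; keeps compilers quiet
}

int main() {
  // Mirrors the LSTMKernel change: resolve the attribute strings once,
  // then pass plain enums down to the per-frame compute loop.
  ActivationType gate_act = GetActivationType("sigmoid");
  ActivationType cand_act = GetActivationType("tanh");
  std::cout << apply_activation(0.5f, gate_act) << " "
            << apply_activation(0.5f, cand_act) << "\n";
  return 0;
}

Resolving the attribute once in the op kernel keeps string handling out of the CPU/CUDA inner loops and lets device code receive a trivially copyable enum rather than a std::string.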