diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h
index 1ce8b5fbe4acad0e973fa7f05c194b4b5211bf20..c57ee414dc5b3417549c8ac3a7fd57a9c8f452df 100644
--- a/paddle/operators/lstm_op.h
+++ b/paddle/operators/lstm_op.h
@@ -14,10 +14,10 @@ limitations under the License. */
 
 #pragma once
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/operators/math/lstm_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
-#include "paddle/operators/math/detail/activation_functions.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/math/detail/activation_functions.h b/paddle/operators/math/detail/activation_functions.h
index 9e8b591cf48586ed78d902e573e8ed02afb04459..585a0123437a39c2b610306b18fe0a970c0ed072 100644
--- a/paddle/operators/math/detail/activation_functions.h
+++ b/paddle/operators/math/detail/activation_functions.h
@@ -14,8 +14,8 @@ limitations under the License. */
 
 #pragma once
 #include <math.h>
-#include "paddle/platform/hostdevice.h"
 #include "paddle/platform/enforce.h"
+#include "paddle/platform/hostdevice.h"
 
 #ifdef __AVX__
 #include <immintrin.h>
@@ -37,20 +37,19 @@ enum ActivationType {
   kIdentity,
 };
 
-inline ActivationType GetActivationType (const std::string &type) {
+inline ActivationType GetActivationType(const std::string &type) {
   if (type == "sigmoid") {
     return ActivationType::kSigmoid;
   } else if (type == "relu") {
     return ActivationType::kReLU;
   } else if (type == "tanh") {
     return ActivationType::kTanh;
-  } else if (type == "identity") {
+  } else if (type == "identity" || type == "") {
     return ActivationType::kIdentity;
   }
   PADDLE_THROW("Not support type %s.", type);
 }
 
-
 namespace forward {
 
 template <typename T>
diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h
index b37d85b7399a462bbf897ef0e838d20bb01971bf..42888fcdb0a464892e3007ee73c195fcd2a431bb 100644
--- a/paddle/operators/math/detail/lstm_cpu_kernel.h
+++ b/paddle/operators/math/detail/lstm_cpu_kernel.h
@@ -26,8 +26,7 @@ namespace detail {
 
 template <class T, class Op>
 void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                     int frame_size,
-                                     ActivationType active_node,
+                                     int frame_size, ActivationType active_node,
                                      ActivationType active_gate,
                                      ActivationType active_state) {
   T r_value_in;
@@ -149,8 +148,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 
 template <class T, class Op>
 void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                   int frame_size,
-                                   ActivationType active_node,
+                                   int frame_size, ActivationType active_node,
                                    ActivationType active_gate,
                                    ActivationType active_state) {
 #ifdef __AVX__
@@ -281,8 +279,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 
 template <class T, class Op>
 void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
-                      ActivationType active_node,
-                      ActivationType active_gate,
+                      ActivationType active_node, ActivationType active_gate,
                       ActivationType active_state) {
   if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
     avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h
index e1a787deeef7b403398ca811e232269c8895a616..e31e657e8b6964c2b99f6e456545c83d8da8e7f9 100644
--- a/paddle/operators/math/detail/lstm_gpu_kernel.h
+++ b/paddle/operators/math/detail/lstm_gpu_kernel.h
@@ -185,8 +185,7 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      ActivationType active_node,
-                      ActivationType active_gate,
+                      ActivationType active_node, ActivationType active_gate,
                       ActivationType active_state) {
   dim3 threads;
   dim3 grid;
@@ -220,8 +219,7 @@ template <class T, class Op>
 void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
                        LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                        int frame_size, int batch_size,
-                       ActivationType active_node,
-                       ActivationType active_gate,
+                       ActivationType active_node, ActivationType active_gate,
                        ActivationType active_state) {
   dim3 threads;
   dim3 grid;
diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu
index 4d8651e39760d6270ff0e6246e9b6c32ca941e03..82065d699f760db6cc86bf3d6c56e51c583c6ace 100644
--- a/paddle/operators/math/lstm_compute.cu
+++ b/paddle/operators/math/lstm_compute.cu
@@ -28,8 +28,8 @@ struct LstmUnitFunctor<platform::GPUPlace, T> {
                       const detail::ActivationType& cell_act,
                       const detail::ActivationType& cand_act) {
     detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
-                                frame_size, batch_size, cand_act,
-                                gate_act, cell_act);
+                                frame_size, batch_size, cand_act, gate_act,
+                                cell_act);
   }
 };
 
@@ -42,8 +42,8 @@ struct LstmUnitGradFunctor<platform::GPUPlace, T> {
                       const detail::ActivationType& cell_act,
                       const detail::ActivationType& cand_act) {
     detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
-                              frame_size, batch_size, cand_act,
-                              gate_act, cell_act);
+                              frame_size, batch_size, cand_act, gate_act,
+                              cell_act);
   }
 };
 
diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h
index 4935f8ebd15b7412df1e9b053f38e55b28691a99..954762f92286fe13bd2c08ec03c3ac96bb663cca 100644
--- a/paddle/operators/math/lstm_compute.h
+++ b/paddle/operators/math/lstm_compute.h
@@ -14,9 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
-#include "paddle/operators/math/detail/activation_functions.h"
 
 namespace paddle {
 namespace operators {
@@ -30,7 +30,6 @@ typedef enum {
   HL_ACTIVATION_END
 } activation_mode_t;
 
-
 template <class T>
 struct LstmMetaValue {
   T *gate_value;
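
Reviewer note: apart from the include reordering and clang-format line
reflows, the one behavioral change in this patch is that GetActivationType()
now maps an empty string to kIdentity instead of reaching PADDLE_THROW.
Below is a minimal standalone sketch of the patched lookup for quick local
verification; the enum and function body mirror the diff, while main() and
the use of std::invalid_argument in place of PADDLE_THROW are illustrative
stand-ins, not PaddlePaddle code:

  #include <iostream>
  #include <stdexcept>
  #include <string>

  // Mirrors ActivationType from activation_functions.h above.
  enum ActivationType { kSigmoid, kReLU, kTanh, kIdentity };

  // Patched lookup: the new `type == "identity" || type == ""` branch lets
  // an empty activation string resolve to kIdentity instead of throwing.
  // std::invalid_argument stands in for PADDLE_THROW here.
  ActivationType GetActivationType(const std::string &type) {
    if (type == "sigmoid") return ActivationType::kSigmoid;
    if (type == "relu") return ActivationType::kReLU;
    if (type == "tanh") return ActivationType::kTanh;
    if (type == "identity" || type == "") return ActivationType::kIdentity;
    throw std::invalid_argument("Not support type " + type);
  }

  int main() {
    // Both lines print 1: "" now behaves like "identity".
    std::cout << (GetActivationType("") == ActivationType::kIdentity) << "\n";
    std::cout << (GetActivationType("tanh") == ActivationType::kTanh) << "\n";
    return 0;
  }

Treating "" as identity presumably keeps callers working when an activation
attribute is left unset; before this patch that input fell through every
branch and hit the PADDLE_THROW path.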