diff --git a/paddle/operators/math/gru_compute.cu b/paddle/operators/math/gru_compute.cu index aab3e2309bb952a40136e1eeac018227c0398fbc..d5a0e630ea0eadea990988c3170395c842a91900 100644 --- a/paddle/operators/math/gru_compute.cu +++ b/paddle/operators/math/gru_compute.cu @@ -22,7 +22,8 @@ template struct GRUUnitFunctor { static void compute(const platform::CUDADeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, - ActivationType active_node, ActivationType active_gate) { + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { auto stream = context.stream(); dim3 threads; dim3 grid; @@ -89,7 +90,8 @@ struct GRUUnitGradFunctor { static void compute(const platform::CUDADeviceContext &context, GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - ActivationType active_node, ActivationType active_gate) { + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { auto stream = context.stream(); dim3 threads; dim3 grid; diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index 954762f92286fe13bd2c08ec03c3ac96bb663cca..e1ad6b64d201ef99d83eaa2a821356008dcc9c8e 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -22,14 +22,6 @@ namespace paddle { namespace operators { namespace math { -typedef enum { - HL_ACTIVATION_SIGMOID = 0, - HL_ACTIVATION_RELU = 1, - HL_ACTIVATION_TANH = 2, - HL_ACTIVATION_LINEAR = 3, - HL_ACTIVATION_END -} activation_mode_t; - template struct LstmMetaValue { T *gate_value; @@ -54,20 +46,6 @@ struct LstmMetaGrad { T *check_og_grad; }; -inline activation_mode_t ActiveType(const std::string &type) { - if (type == "sigmoid") { - return HL_ACTIVATION_SIGMOID; - } else if (type == "relu") { - return HL_ACTIVATION_RELU; - } else if (type == "tanh") { - return HL_ACTIVATION_TANH; - } else if (type == "linear" || type == "identity" || type == "") { - return HL_ACTIVATION_LINEAR; - } else { - PADDLE_THROW("Do not support activation type."); - } -} - template class LstmUnitFunctor { public: