提交 23b53c48 编写于 作者: G guosheng

Delete the old activation type for LSTM and GRU operator

上级 f74dff97
...@@ -22,7 +22,8 @@ template <typename T> ...@@ -22,7 +22,8 @@ template <typename T>
struct GRUUnitFunctor<platform::CUDADeviceContext, T> { struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context, static void compute(const platform::CUDADeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size, GRUMetaValue<T> value, int frame_size, int batch_size,
ActivationType active_node, ActivationType active_gate) { const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
auto stream = context.stream(); auto stream = context.stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
...@@ -89,7 +90,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> { ...@@ -89,7 +90,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context, static void compute(const platform::CUDADeviceContext &context,
GRUMetaValue<T> value, GRUMetaGrad<T> grad, GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
ActivationType active_node, ActivationType active_gate) { const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
auto stream = context.stream(); auto stream = context.stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
......
...@@ -22,14 +22,6 @@ namespace paddle { ...@@ -22,14 +22,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
typedef enum {
HL_ACTIVATION_SIGMOID = 0,
HL_ACTIVATION_RELU = 1,
HL_ACTIVATION_TANH = 2,
HL_ACTIVATION_LINEAR = 3,
HL_ACTIVATION_END
} activation_mode_t;
template <class T> template <class T>
struct LstmMetaValue { struct LstmMetaValue {
T *gate_value; T *gate_value;
...@@ -54,20 +46,6 @@ struct LstmMetaGrad { ...@@ -54,20 +46,6 @@ struct LstmMetaGrad {
T *check_og_grad; T *check_og_grad;
}; };
inline activation_mode_t ActiveType(const std::string &type) {
if (type == "sigmoid") {
return HL_ACTIVATION_SIGMOID;
} else if (type == "relu") {
return HL_ACTIVATION_RELU;
} else if (type == "tanh") {
return HL_ACTIVATION_TANH;
} else if (type == "linear" || type == "identity" || type == "") {
return HL_ACTIVATION_LINEAR;
} else {
PADDLE_THROW("Do not support activation type.");
}
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class LstmUnitFunctor { class LstmUnitFunctor {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册