Commit 23b53c48 authored by guosheng

Delete the old activation type for the LSTM and GRU operators

Parent f74dff97
@@ -22,7 +22,8 @@ template <typename T>
 struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext &context,
                       GRUMetaValue<T> value, int frame_size, int batch_size,
-                      ActivationType active_node, ActivationType active_gate) {
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
@@ -89,7 +90,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext &context,
                       GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      ActivationType active_node, ActivationType active_gate) {
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
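Note: with the new signatures above, call sites are expected to resolve string attributes into the `detail::ActivationType` before invoking the functor. A minimal sketch of such a caller inside an operator kernel, assuming the hypothetical attribute names "activation" and "gate_activation" and a `detail::GetActivationType` string-to-enum helper:

    // Hypothetical call site (names of the attributes and the helper are
    // assumptions, not taken from this commit): resolve the string
    // attributes to detail::ActivationType, then invoke the functor.
    auto active_node = math::detail::GetActivationType(
        context.Attr<std::string>("activation"));
    auto active_gate = math::detail::GetActivationType(
        context.Attr<std::string>("gate_activation"));
    math::GRUUnitFunctor<platform::CUDADeviceContext, T>::compute(
        dev_ctx, gru_value, frame_size, batch_size, active_node, active_gate);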
@@ -22,14 +22,6 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-typedef enum {
-  HL_ACTIVATION_SIGMOID = 0,
-  HL_ACTIVATION_RELU = 1,
-  HL_ACTIVATION_TANH = 2,
-  HL_ACTIVATION_LINEAR = 3,
-  HL_ACTIVATION_END
-} activation_mode_t;
-
 template <class T>
 struct LstmMetaValue {
   T *gate_value;
@@ -54,20 +46,6 @@ struct LstmMetaGrad {
   T *check_og_grad;
 };
 
-inline activation_mode_t ActiveType(const std::string &type) {
-  if (type == "sigmoid") {
-    return HL_ACTIVATION_SIGMOID;
-  } else if (type == "relu") {
-    return HL_ACTIVATION_RELU;
-  } else if (type == "tanh") {
-    return HL_ACTIVATION_TANH;
-  } else if (type == "linear" || type == "identity" || type == "") {
-    return HL_ACTIVATION_LINEAR;
-  } else {
-    PADDLE_THROW("Do not support activation type.");
-  }
-}
-
 template <typename DeviceContext, typename T>
 class LstmUnitFunctor {
  public:
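For reference, the deleted `activation_mode_t` enum and `ActiveType` helper are superseded by an `ActivationType` enum and a string-to-enum helper in the `detail` namespace (the diff confirms the namespace; the file, e.g. `detail/activation_functions.h`, and the enumerator names and values below are assumptions, not copied from this commit). A sketch of what that replacement plausibly looks like:

    // Sketch (assumed): replacement living in the detail namespace.
    // Enumerator spelling and values are not verified against this commit.
    namespace detail {

    enum ActivationType {
      kSigmoid = 1,
      kReLU,
      kTanh,
      kIdentity,
    };

    inline ActivationType GetActivationType(const std::string &type) {
      if (type == "sigmoid") return ActivationType::kSigmoid;
      if (type == "relu") return ActivationType::kReLU;
      if (type == "tanh") return ActivationType::kTanh;
      // The old "linear" spelling maps onto identity, as ActiveType did.
      if (type == "linear" || type == "identity" || type == "")
        return ActivationType::kIdentity;
      PADDLE_THROW("Do not support activation type.");
    }

    }  // namespace detail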