Commit a6ff5240 authored by guosheng

Refine the activation type of GRUOp by following comments

Parent 23b53c48
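This change hoists the two math::detail::GetActivationType lookups (on the activation and gate_activation string attributes) out of the per-batch loops in both the forward and backward GRU kernels, so each attribute is resolved once per kernel invocation rather than once per batch step; a sketch of the pattern follows the diff.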
@@ -90,6 +90,10 @@ class GRUKernel : public framework::OpKernel<T> {
     }
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -102,11 +106,8 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.gate_value = gate_t.data<T>();
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
       math::GRUUnitFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, frame_size, cur_batch_size,
-          math::detail::GetActivationType(
-              context.Attr<std::string>("activation")),
-          math::detail::GetActivationType(
-              context.Attr<std::string>("gate_activation")));
+          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+          active_gate);
       gru_value.prev_out_value = gru_value.output_value;
     }
@@ -192,6 +193,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
     auto batch_starts = batch_hidden_grad.lod()[0];
     size_t num_batch = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
     for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -222,11 +227,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
       }
       math::GRUUnitGradFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size,
-          math::detail::GetActivationType(
-              context.Attr<std::string>("activation")),
-          math::detail::GetActivationType(
-              context.Attr<std::string>("gate_activation")));
+          dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
+          active_gate);
     }
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
...
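For readers outside the Paddle codebase, here is a minimal, self-contained sketch of the refactoring pattern applied above. The ActivationType enum, GetActivationType parser, and FakeContext attribute holder are simplified stand-ins invented for illustration, not the real framework types; only the hoisting of the lookups out of the loop mirrors the commit.

```cpp
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for the real Paddle types, for illustration only.
enum class ActivationType { kIdentity, kSigmoid, kTanh, kReLU };

ActivationType GetActivationType(const std::string& name) {
  if (name == "identity") return ActivationType::kIdentity;
  if (name == "sigmoid") return ActivationType::kSigmoid;
  if (name == "tanh") return ActivationType::kTanh;
  if (name == "relu") return ActivationType::kReLU;
  throw std::invalid_argument("unknown activation: " + name);
}

struct FakeContext {
  std::map<std::string, std::string> attrs;
  const std::string& Attr(const std::string& key) const {
    return attrs.at(key);
  }
};

void BatchCompute(const FakeContext& context, std::size_t num_batch) {
  // Before the commit, the two GetActivationType calls sat inside the loop,
  // re-parsing the same string attributes on every batch step; hoisting them
  // out resolves each attribute exactly once per kernel invocation.
  auto active_node = GetActivationType(context.Attr("activation"));
  auto active_gate = GetActivationType(context.Attr("gate_activation"));
  for (std::size_t n = 0; n < num_batch; n++) {
    // The per-batch GRU computation would consume the cached enum values
    // here instead of re-deriving them from strings each iteration.
    (void)active_node;
    (void)active_gate;
  }
}

int main() {
  FakeContext context{{{"activation", "tanh"}, {"gate_activation", "sigmoid"}}};
  BatchCompute(context, 3);
  return 0;
}
```

The per-step cost of the string comparison is small, but pulling loop-invariant work out of the batch loop also makes the compute calls shorter and easier to read, which is the refinement the review comments asked for.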