提交 75426e01 编写于 作者: G guosheng

Refine GRU Operator

上级 2bed9612
...@@ -154,6 +154,7 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -154,6 +154,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
} }
if (h0_grad) { if (h0_grad) {
ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace()); ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
zero(context.device_context(), &ordered_h0_grad, static_cast<T>(0.0));
} }
bool is_reverse = context.Attr<bool>("is_reverse"); bool is_reverse = context.Attr<bool>("is_reverse");
......
...@@ -149,7 +149,7 @@ class TestGRUOpReverse(TestGRUOp): ...@@ -149,7 +149,7 @@ class TestGRUOpReverse(TestGRUOp):
self.is_reverse = True self.is_reverse = True
self.attrs = { self.attrs = {
'activation': 'tanh', 'activation': 'tanh',
'gate_activation': 'tanh', 'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse 'is_reverse': self.is_reverse
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册