From 75426e013a8af9a327a1c47008719053a4df8dff Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 16 Nov 2017 11:24:08 +0800 Subject: [PATCH] Refine GRU Operator --- paddle/operators/gru_op.h | 1 + python/paddle/v2/framework/tests/test_gru_op.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index b2cf358994..9fb60e20d1 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -154,6 +154,7 @@ class GRUGradKernel : public framework::OpKernel { } if (h0_grad) { ordered_h0_grad.mutable_data(h0_grad->dims(), context.GetPlace()); + zero(context.device_context(), &ordered_h0_grad, static_cast(0.0)); } bool is_reverse = context.Attr("is_reverse"); diff --git a/python/paddle/v2/framework/tests/test_gru_op.py b/python/paddle/v2/framework/tests/test_gru_op.py index 2bb78d10e0..fa2c5a53ec 100644 --- a/python/paddle/v2/framework/tests/test_gru_op.py +++ b/python/paddle/v2/framework/tests/test_gru_op.py @@ -149,7 +149,7 @@ class TestGRUOpReverse(TestGRUOp): self.is_reverse = True self.attrs = { 'activation': 'tanh', - 'gate_activation': 'tanh', + 'gate_activation': 'sigmoid', 'is_reverse': self.is_reverse } -- GitLab