From 7a81ab860764c9cd667a09a9739d5fb4de3d2806 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 6 Jan 2019 22:23:43 +0800 Subject: [PATCH] complete gru_unite_op and test --- paddle/fluid/operators/gru_unit_op.h | 22 +++++++++++++------ .../fluid/tests/unittests/test_gru_unit_op.py | 14 +++++++----- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h index ed0c689cfe5..712ef05d863 100644 --- a/paddle/fluid/operators/gru_unit_op.h +++ b/paddle/fluid/operators/gru_unit_op.h @@ -184,11 +184,19 @@ class GRUUnitGradKernel : public framework::OpKernel { auto c = g.slice(c_offsets, extents); // output candidate // backward for unactivated update gate - ActGradCompute(context.Attr("gate_activation"), place, u, u, - d_g.slice(u_offsets, extents), d_h * (c - h_p)); - // backward for unactivated output candidate - ActGradCompute(context.Attr("activation"), place, c, c, - d_g.slice(c_offsets, extents), d_h * u); + if (context.Attr("origin_mode")) { + ActGradCompute(context.Attr("gate_activation"), place, u, u, + d_g.slice(u_offsets, extents), d_h * (h_p - c)); + // backward for unactivated output candidate + ActGradCompute(context.Attr("activation"), place, c, c, + d_g.slice(c_offsets, extents), d_h * (1 - u)); + } else { + ActGradCompute(context.Attr("gate_activation"), place, u, u, + d_g.slice(u_offsets, extents), d_h * (c - h_p)); + // backward for unactivated output candidate + ActGradCompute(context.Attr("activation"), place, c, c, + d_g.slice(c_offsets, extents), d_h * u); + } // backward for reset_hidden_prev auto blas = math::GetBlas(context); blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, @@ -218,9 +226,9 @@ class GRUUnitGradKernel : public framework::OpKernel { hidden_prev_grad->mutable_data(context.GetPlace()); auto d_h_p = EigenMatrix::From(*hidden_prev_grad); if (context.Attr("origin_mode")) { - d_h_p.device(place) = d_r_h_p * (u.constant(T(1)) - u) + d_h * r; + d_h_p.device(place) = d_r_h_p * r + d_h * u; } else { - d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); + d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u); } blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, diff --git a/python/paddle/fluid/tests/unittests/test_gru_unit_op.py b/python/paddle/fluid/tests/unittests/test_gru_unit_op.py index fdc6b75dbda..78f2f030f5b 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_unit_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_unit_op.py @@ -53,7 +53,7 @@ class TestGRUUnitOp(OpTest): GRUActivationType.relu: relu, } - def set_inputs(self): + def set_inputs(self, origin_mode=False): batch_size = self.batch_size frame_size = self.frame_size self.op_type = 'gru_unit' @@ -68,7 +68,8 @@ class TestGRUUnitOp(OpTest): } self.attrs = { 'activation': GRUActivationType.tanh, - 'gate_activation': GRUActivationType.sigmoid + 'gate_activation': GRUActivationType.sigmoid, + 'origin_mode': origin_mode } def set_outputs(self, origin_mode=False): @@ -116,12 +117,12 @@ class TestGRUUnitOp(OpTest): class TestGRUUnitOpOriginMode(TestGRUUnitOp): def setUp(self): - self.set_inputs() + self.set_inputs(origin_mode=True) self.set_outputs(origin_mode=True) class TestGRUUnitOpWithBias(TestGRUUnitOp): - def set_inputs(self): + def set_inputs(self, origin_mode=False): batch_size = self.batch_size frame_size = self.frame_size super(TestGRUUnitOpWithBias, self).set_inputs() @@ -129,7 +130,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): -0.1, 0.1, (1, frame_size * 3)).astype('float64') self.attrs = { 'activation': GRUActivationType.identity, - 'gate_activation': GRUActivationType.sigmoid + 'gate_activation': GRUActivationType.sigmoid, + 'origin_mode': origin_mode } def test_check_grad(self): @@ -143,7 +145,7 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias): def setUp(self): - self.set_inputs() + self.set_inputs(origin_mode=True) self.set_outputs(origin_mode=True) -- GitLab