Add an `origin_mode` switch to the GRU unit operator.

What the hunks below do:
  * gru_unit_op.cc  - registers a bool attribute `origin_mode`, default false.
  * gru_unit_op.h   - forward kernel selects the final-output formula:
                        origin_mode == true :  h = (1 - u) * c + u * h_p
                        origin_mode == false:  h = u * c + (1 - u) * h_p
                      (the false branch is the expression used before this
                      patch, so the default behaviour is unchanged).
  * layers/nn.py    - `gru_unit(...)` gains `origin_mode=False`.
  * test_gru_op.py  - the NumPy reference `gru(...)` mirrors both formulas and
                      three OriginMode test cases are added.

NOTE(review): only the signature hunk of `gru_unit` is visible here - confirm
the new argument is actually forwarded into the op's attrs.
NOTE(review): no hunk touches a gradient kernel - confirm the backward pass
matches the formula selected by `origin_mode`.

diff --git a/paddle/fluid/operators/gru_unit_op.cc b/paddle/fluid/operators/gru_unit_op.cc
index 82a808b01e99ec33b0ca00a065fb301d3c633b19..6d91d0d5c0dad14c84dcddbfdfcf4edb82db7687 100644
--- a/paddle/fluid/operators/gru_unit_op.cc
+++ b/paddle/fluid/operators/gru_unit_op.cc
@@ -111,6 +111,10 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
                  "The activation type used in update gate and reset gate.")
         .SetDefault(sigmoid)
         .InEnum({identity, sigmoid, tanh, relu});
+    AddAttr<bool>("origin_mode",
+                  "bool, "
+                  "use origin mode in article https://arxiv.org/abs/1412.3555")
+        .SetDefault(false);
     AddComment(R"DOC(
 GRUUnit Operator implements partial calculations of the GRU unit as following:
 
diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h
index 451ec61ba1f7239d92c6dfbad0b2961e74e1bc17..9f56a0ef73156069ae433fe47242e5ac3bf09c15 100644
--- a/paddle/fluid/operators/gru_unit_op.h
+++ b/paddle/fluid/operators/gru_unit_op.h
@@ -113,7 +113,12 @@ class GRUUnitKernel : public framework::OpKernel {
     auto c = g.slice(c_offsets, extents);  // output candidate
 
     // calculate final output
-    h.device(place) = u * (c - h_p) + h_p;
+    bool origin_mode = context.Attr<bool>("origin_mode");
+    if (origin_mode) {
+      h.device(place) = c + u * (h_p - c);  // (1 - u) * c + u * h_p
+    } else {
+      h.device(place) = u * (c - h_p) + h_p;  // u * c + (1 - u) * h_p
+    }
   }
 };
 
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 9572fcb385823eab16d5c44fd56c680e577c8f04..d9c6b02c4cb6f86ee48c72c7a84491dc186393a6 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -991,7 +991,8 @@ def gru_unit(input,
              param_attr=None,
              bias_attr=None,
              activation='tanh',
-             gate_activation='sigmoid'):
+             gate_activation='sigmoid',
+             origin_mode=False):
     """
     GRU unit layer. The equation of a gru step is:
 
diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py
index f61a447fd77d001c1a9c42c46ba9a0d747e926ec..6606162733487b15ef55f1a4677fb382e6e7e0ac 100644
--- a/python/paddle/fluid/tests/unittests/test_gru_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gru_op.py
@@ -31,7 +31,8 @@ def gru(
         is_reverse,
         act_state,
         act_gate,
-        dtype='float32'):
+        dtype='float32',
+        origin_mode=False):
     def _seq_to_batch(lod, is_reverse):
         idx_in_seq_list = []
         seq_lens = lod[0]
@@ -66,7 +67,10 @@ def gru(
         w_c = w.flatten()[D * D * 2:].reshape((D, D))
         c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
         g = np.hstack((u_r, c))
-        h = u * c + (1 - u) * h_p
+        if origin_mode:
+            h = (1 - u) * c + u * h_p
+        else:
+            h = u * c + (1 - u) * h_p
         return g, r_h_p, h
 
     T = sum(lod[0])
@@ -110,6 +114,7 @@ class TestGRUOp(OpTest):
         self.act_state = 'tanh'
         self.act_gate = 'sigmoid'
         self.dtype = 'float64'
+        self.origin_mode = False
         self.set_confs()
 
         T = sum(self.lod[0])
@@ -126,7 +131,8 @@ class TestGRUOp(OpTest):
 
         batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
             input, self.lod, h0, weight, bias, self.is_reverse,
-            ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
+            ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype,
+            self.origin_mode)
         self.inputs = {'Input': (input, self.lod), 'Weight': weight}
 
         if self.with_bias:
@@ -145,7 +151,8 @@ class TestGRUOp(OpTest):
         self.attrs = {
             'activation': self.act_state,
             'gate_activation': self.act_gate,
-            'is_reverse': self.is_reverse
+            'is_reverse': self.is_reverse,
+            'origin_mode': self.origin_mode
         }
 
     def test_check_output(self):
@@ -155,12 +162,24 @@ class TestGRUOp(OpTest):
         self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
 
 
+class TestGRUOriginMode(TestGRUOp):
+    def set_confs(self):
+        self.origin_mode = True
+
+
 class TestGRUOp2(TestGRUOp):
     def set_confs(self):
         self.D = 19
         self.dtype = 'float32'
 
 
+class TestGRUOp2OriginMode(TestGRUOp):
+    def set_confs(self):
+        self.D = 19
+        self.dtype = 'float32'
+        self.origin_mode = True
+
+
 class TestGRUOpNoInitial(TestGRUOp):
     def set_confs(self):
         self.with_h0 = False
@@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp):
         self.is_reverse = True
 
 
+class TestGRUOpReverseOriginMode(TestGRUOp):
+    def set_confs(self):
+        self.is_reverse = True
+        self.origin_mode = True
+
+
 if __name__ == "__main__":
     unittest.main()