Commit 78ec7c0f authored by Qiao Longfei

gru add origin mode

test=develop
Parent 67093da3
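This commit adds an `origin_mode` attribute to the GRU unit operator. With `origin_mode=True` the final output follows the formulation in the referenced article (https://arxiv.org/abs/1412.3555):

    h_t = (1 - u_t) * c_t + u_t * h_{t-1}

With the default `origin_mode=False`, the existing PaddlePaddle behavior is kept:

    h_t = u_t * c_t + (1 - u_t) * h_{t-1}

where u_t is the update gate, c_t the candidate state, and h_{t-1} the previous hidden state. The two modes swap the roles of the interpolation weights, so the flag matters for loading weights trained under the other convention.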
@@ -111,6 +111,10 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
         "The activation type used in update gate and reset gate.")
         .SetDefault(sigmoid)
         .InEnum({identity, sigmoid, tanh, relu});
+    AddAttr<bool>("origin_mode",
+                  "bool"
+                  "use origin mode in article https://arxiv.org/abs/1412.3555")
+        .SetDefault(false);
     AddComment(R"DOC(
 GRUUnit Operator implements partial calculations of the GRU unit as following:
......
@@ -113,7 +113,12 @@ class GRUUnitKernel : public framework::OpKernel<T> {
     auto c = g.slice(c_offsets, extents);  // output candidate

     // calculate final output
-    h.device(place) = u * (c - h_p) + h_p;
+    bool origin_mode = context.Attr<bool>("origin_mode");
+    if (origin_mode) {
+      h.device(place) = c + u * (h_p - c);  // (1 - u) * c + u * h_p
+    } else {
+      h.device(place) = u * (c - h_p) + h_p;  // u * c + (1 - u) * h_p
+    }
   }
 };
......
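The kernel computes the origin-mode output as `c + u * (h_p - c)` rather than the literal `(1 - u) * c + u * h_p` from the comment; the rewrite saves a temporary. A minimal NumPy sketch (illustration only, not part of this commit) confirms that the two forms of each mode are algebraically identical while the two modes themselves differ:

    # Check the algebraic rewrites quoted in the kernel comments above.
    import numpy as np

    rng = np.random.RandomState(0)
    u = rng.uniform(size=(4, 8))   # update gate activations in (0, 1)
    c = rng.randn(4, 8)            # candidate hidden state
    h_p = rng.randn(4, 8)          # previous hidden state

    h_origin = c + u * (h_p - c)            # origin-mode form in the kernel
    assert np.allclose(h_origin, (1 - u) * c + u * h_p)

    h_default = u * (c - h_p) + h_p         # default-mode form in the kernel
    assert np.allclose(h_default, u * c + (1 - u) * h_p)

    # The two modes are distinct computations, not rewrites of each other.
    assert not np.allclose(h_origin, h_default)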
@@ -991,7 +991,8 @@ def gru_unit(input,
              param_attr=None,
              bias_attr=None,
              activation='tanh',
-             gate_activation='sigmoid'):
+             gate_activation='sigmoid',
+             origin_mode=False):
     """
     GRU unit layer. The equation of a gru step is:
......
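For reference, a hypothetical call sketch for the updated Python wrapper. Only the keyword arguments shown in the hunk above appear in this diff; the positional arguments, shapes, return values, and the `fluid` import path are assumptions about the surrounding API:

    import paddle.fluid as fluid

    # x_proj: input already projected to hidden_dim * 3; prev_h: previous state.
    hidden, reset_hidden_pre, gate = fluid.layers.gru_unit(
        input=x_proj,
        hidden=prev_h,
        size=hidden_dim * 3,
        origin_mode=True)  # new flag; the default False keeps the old update rule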
@@ -31,7 +31,8 @@ def gru(
         is_reverse,
         act_state,
         act_gate,
-        dtype='float32'):
+        dtype='float32',
+        origin_mode=False):
     def _seq_to_batch(lod, is_reverse):
         idx_in_seq_list = []
         seq_lens = lod[0]
@@ -66,7 +67,10 @@ def gru(
         w_c = w.flatten()[D * D * 2:].reshape((D, D))
         c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
         g = np.hstack((u_r, c))
-        h = u * c + (1 - u) * h_p
+        if origin_mode:
+            h = (1 - u) * c + u * h_p
+        else:
+            h = u * c + (1 - u) * h_p
         return g, r_h_p, h

     T = sum(lod[0])
@@ -110,6 +114,7 @@ class TestGRUOp(OpTest):
         self.act_state = 'tanh'
         self.act_gate = 'sigmoid'
         self.dtype = 'float64'
+        self.origin_mode = False
         self.set_confs()

         T = sum(self.lod[0])
@@ -126,7 +131,8 @@ class TestGRUOp(OpTest):
         batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
             input, self.lod, h0, weight, bias, self.is_reverse,
-            ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
+            ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype,
+            self.origin_mode)
         self.inputs = {'Input': (input, self.lod), 'Weight': weight}

         if self.with_bias:
@@ -145,7 +151,8 @@ class TestGRUOp(OpTest):
         self.attrs = {
             'activation': self.act_state,
             'gate_activation': self.act_gate,
-            'is_reverse': self.is_reverse
+            'is_reverse': self.is_reverse,
+            'origin_mode': self.origin_mode
         }

     def test_check_output(self):
@@ -155,12 +162,24 @@ class TestGRUOp(OpTest):
         self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])


+class TestGRUOriginMode(TestGRUOp):
+    def set_confs(self):
+        self.origin_mode = True
+
+
 class TestGRUOp2(TestGRUOp):
     def set_confs(self):
         self.D = 19
         self.dtype = 'float32'


+class TestGRUOp2OriginMode(TestGRUOp):
+    def set_confs(self):
+        self.D = 19
+        self.dtype = 'float32'
+        self.origin_mode = True
+
+
 class TestGRUOpNoInitial(TestGRUOp):
     def set_confs(self):
         self.with_h0 = False
@@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp):
         self.is_reverse = True


+class TestGRUOpReverseOriginMode(TestGRUOp):
+    def set_confs(self):
+        self.is_reverse = True
+        self.origin_mode = True
+
+
 if __name__ == "__main__":
     unittest.main()
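The new cases run like any other OpTest subclass; for example (the test module name is an assumption, it is not shown in this diff):

    python -m unittest test_gru_op.TestGRUOriginMode test_gru_op.TestGRUOpReverseOriginMode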