提交 78ec7c0f 编写于 作者: Q Qiao Longfei

gru add origin mode

test=develop
上级 67093da3
......@@ -111,6 +111,10 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
"The activation type used in update gate and reset gate.")
.SetDefault(sigmoid)
.InEnum({identity, sigmoid, tanh, relu});
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555")
.SetDefault(false);
AddComment(R"DOC(
GRUUnit Operator implements partial calculations of the GRU unit as following:
......
......@@ -113,7 +113,12 @@ class GRUUnitKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output
h.device(place) = u * (c - h_p) + h_p;
bool origin_mode = context.Attr<bool>("origin_mode");
if (origin_mode) {
h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p
} else {
h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p
}
}
};
......
......@@ -991,7 +991,8 @@ def gru_unit(input,
param_attr=None,
bias_attr=None,
activation='tanh',
gate_activation='sigmoid'):
gate_activation='sigmoid',
origin_mode=False):
"""
GRU unit layer. The equation of a gru step is:
......
......@@ -31,7 +31,8 @@ def gru(
is_reverse,
act_state,
act_gate,
dtype='float32'):
dtype='float32',
origin_mode=False):
def _seq_to_batch(lod, is_reverse):
idx_in_seq_list = []
seq_lens = lod[0]
......@@ -66,6 +67,9 @@ def gru(
w_c = w.flatten()[D * D * 2:].reshape((D, D))
c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
g = np.hstack((u_r, c))
if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p
return g, r_h_p, h
......@@ -110,6 +114,7 @@ class TestGRUOp(OpTest):
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.dtype = 'float64'
self.origin_mode = False
self.set_confs()
T = sum(self.lod[0])
......@@ -126,7 +131,8 @@ class TestGRUOp(OpTest):
batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
input, self.lod, h0, weight, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype,
self.origin_mode)
self.inputs = {'Input': (input, self.lod), 'Weight': weight}
if self.with_bias:
......@@ -145,7 +151,8 @@ class TestGRUOp(OpTest):
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
'is_reverse': self.is_reverse,
'origin_mode': self.origin_mode
}
def test_check_output(self):
......@@ -155,12 +162,24 @@ class TestGRUOp(OpTest):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOriginMode(TestGRUOp):
def set_confs(self):
self.origin_mode = True
class TestGRUOp2(TestGRUOp):
def set_confs(self):
self.D = 19
self.dtype = 'float32'
class TestGRUOp2OriginMode(TestGRUOp):
def set_confs(self):
self.D = 19
self.dtype = 'float32'
self.origin_mode = True
class TestGRUOpNoInitial(TestGRUOp):
def set_confs(self):
self.with_h0 = False
......@@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp):
self.is_reverse = True
class TestGRUOpReverseOriginMode(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.origin_mode = True
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册