提交 7a81ab86，编写者：Qiao Longfei

complete gru_unit_op and test

上级 816c18a2
......@@ -184,11 +184,19 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate
// backward for unactivated update gate
if (context.Attr<bool>("origin_mode")) {
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (h_p - c));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * (1 - u));
} else {
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (c - h_p));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * u);
}
// backward for reset_hidden_prev
auto blas = math::GetBlas<DeviceContext, T>(context);
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
......@@ -218,9 +226,9 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
hidden_prev_grad->mutable_data<T>(context.GetPlace());
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
if (context.Attr<bool>("origin_mode")) {
d_h_p.device(place) = d_r_h_p * (u.constant(T(1)) - u) + d_h * r;
d_h_p.device(place) = d_r_h_p * r + d_h * u;
} else {
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u);
}
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
......
......@@ -53,7 +53,7 @@ class TestGRUUnitOp(OpTest):
GRUActivationType.relu: relu,
}
def set_inputs(self):
def set_inputs(self, origin_mode=False):
batch_size = self.batch_size
frame_size = self.frame_size
self.op_type = 'gru_unit'
......@@ -68,7 +68,8 @@ class TestGRUUnitOp(OpTest):
}
self.attrs = {
'activation': GRUActivationType.tanh,
'gate_activation': GRUActivationType.sigmoid
'gate_activation': GRUActivationType.sigmoid,
'origin_mode': origin_mode
}
def set_outputs(self, origin_mode=False):
......@@ -116,12 +117,12 @@ class TestGRUUnitOp(OpTest):
# Variant of TestGRUUnitOp that exercises the op with origin_mode=True
# (the attribute toggled by the gradient-kernel branches in this commit).
class TestGRUUnitOpOriginMode(TestGRUUnitOp):
def setUp(self):
# NOTE(review): the two set_inputs calls below are flattened diff
# residue -- the first is the removed (old) line, the second the
# added (new) line. The intended code keeps only the
# origin_mode=True call; confirm against the upstream file.
self.set_inputs()
self.set_inputs(origin_mode=True)
self.set_outputs(origin_mode=True)
class TestGRUUnitOpWithBias(TestGRUUnitOp):
def set_inputs(self):
def set_inputs(self, origin_mode=False):
batch_size = self.batch_size
frame_size = self.frame_size
super(TestGRUUnitOpWithBias, self).set_inputs()
......@@ -129,7 +130,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
-0.1, 0.1, (1, frame_size * 3)).astype('float64')
self.attrs = {
'activation': GRUActivationType.identity,
'gate_activation': GRUActivationType.sigmoid
'gate_activation': GRUActivationType.sigmoid,
'origin_mode': origin_mode
}
def test_check_grad(self):
......@@ -143,7 +145,7 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
# Same as TestGRUUnitOpWithBias, but runs the op with origin_mode=True.
class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias):
def setUp(self):
# NOTE(review): flattened diff residue -- the old call is immediately
# followed by its replacement; the intended code keeps only
# set_inputs(origin_mode=True). Verify against the upstream file.
self.set_inputs()
self.set_inputs(origin_mode=True)
self.set_outputs(origin_mode=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册