Commit 7a81ab86 authored by Qiao Longfei

complete gru_unit_op and test

Parent 816c18a2
```diff
@@ -184,11 +184,19 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
     auto c = g.slice(c_offsets, extents);  // output candidate
     // backward for unactivated update gate
-    ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
-                   d_g.slice(u_offsets, extents), d_h * (c - h_p));
-    // backward for unactivated output candidate
-    ActGradCompute(context.Attr<int>("activation"), place, c, c,
-                   d_g.slice(c_offsets, extents), d_h * u);
+    if (context.Attr<bool>("origin_mode")) {
+      ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
+                     d_g.slice(u_offsets, extents), d_h * (h_p - c));
+      // backward for unactivated output candidate
+      ActGradCompute(context.Attr<int>("activation"), place, c, c,
+                     d_g.slice(c_offsets, extents), d_h * (1 - u));
+    } else {
+      ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
+                     d_g.slice(u_offsets, extents), d_h * (c - h_p));
+      // backward for unactivated output candidate
+      ActGradCompute(context.Attr<int>("activation"), place, c, c,
+                     d_g.slice(c_offsets, extents), d_h * u);
+    }
     // backward for reset_hidden_prev
     auto blas = math::GetBlas<DeviceContext, T>(context);
     blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
```
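The sign flip between the two branches follows from the two GRU update formulations. In the default mode, h = u * c + (1 - u) * h_p, so dh/du = c - h_p and dh/dc = u; in origin_mode (matching the original GRU paper's formulation), h = u * h_p + (1 - u) * c, so dh/du = h_p - c and dh/dc = 1 - u. A minimal NumPy finite-difference check of the origin-mode partial, illustrative only and not part of the commit:

```python
import numpy as np

np.random.seed(0)
h_p = np.random.rand(4, 8)  # previous hidden state
u = np.random.rand(4, 8)    # update gate (after gate_activation)
c = np.random.rand(4, 8)    # candidate (after activation)
eps = 1e-6

# origin_mode forward: h = u * h_p + (1 - u) * c
h = u * h_p + (1 - u) * c
# perturb the update gate and compare against the analytic dh/du = h_p - c
h_eps = (u + eps) * h_p + (1 - (u + eps)) * c
assert np.allclose((h_eps - h) / eps, h_p - c, atol=1e-4)
```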
```diff
@@ -218,9 +226,9 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
           hidden_prev_grad->mutable_data<T>(context.GetPlace());
       auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
       if (context.Attr<bool>("origin_mode")) {
-        d_h_p.device(place) = d_r_h_p * (u.constant(T(1)) - u) + d_h * r;
+        d_h_p.device(place) = d_r_h_p * r + d_h * u;
       } else {
-        d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
+        d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u);
       }
       blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
                 gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
```
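This hunk fixes the previous-hidden-state gradient. The d_r_h_p * r term carries the gradient flowing through the reset-gate product r * h_p in both modes; the direct term is d_h * u in origin_mode (since dh/dh_p = u when h = u * h_p + (1 - u) * c) and d_h * (1 - u) in the default mode. A quick check of the origin-mode direct term, under the same illustrative setup as above:

```python
import numpy as np

np.random.seed(1)
h_p, u, c = np.random.rand(3, 4, 8)  # three (4, 8) arrays
eps = 1e-6

# origin_mode: h = u * h_p + (1 - u) * c, so the direct dh/dh_p term is u
# (the reset-gate path handled by d_r_h_p * r is ignored in this sketch)
h = u * h_p + (1 - u) * c
h_eps = u * (h_p + eps) + (1 - u) * c
assert np.allclose((h_eps - h) / eps, u, atol=1e-4)
```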
```diff
@@ -53,7 +53,7 @@ class TestGRUUnitOp(OpTest):
         GRUActivationType.relu: relu,
     }
 
-    def set_inputs(self):
+    def set_inputs(self, origin_mode=False):
         batch_size = self.batch_size
         frame_size = self.frame_size
         self.op_type = 'gru_unit'
@@ -68,7 +68,8 @@ class TestGRUUnitOp(OpTest):
         }
         self.attrs = {
             'activation': GRUActivationType.tanh,
-            'gate_activation': GRUActivationType.sigmoid
+            'gate_activation': GRUActivationType.sigmoid,
+            'origin_mode': origin_mode
         }
 
     def set_outputs(self, origin_mode=False):
@@ -116,12 +117,12 @@ class TestGRUUnitOp(OpTest):
 class TestGRUUnitOpOriginMode(TestGRUUnitOp):
     def setUp(self):
-        self.set_inputs()
+        self.set_inputs(origin_mode=True)
         self.set_outputs(origin_mode=True)
 
 
 class TestGRUUnitOpWithBias(TestGRUUnitOp):
-    def set_inputs(self):
+    def set_inputs(self, origin_mode=False):
         batch_size = self.batch_size
         frame_size = self.frame_size
         super(TestGRUUnitOpWithBias, self).set_inputs()
@@ -129,7 +130,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
             -0.1, 0.1, (1, frame_size * 3)).astype('float64')
         self.attrs = {
             'activation': GRUActivationType.identity,
-            'gate_activation': GRUActivationType.sigmoid
+            'gate_activation': GRUActivationType.sigmoid,
+            'origin_mode': origin_mode
         }
 
     def test_check_grad(self):
```
```diff
@@ -143,7 +145,7 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
 class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias):
     def setUp(self):
-        self.set_inputs()
+        self.set_inputs(origin_mode=True)
         self.set_outputs(origin_mode=True)
```
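For context on what the updated tests exercise, here is a hedged NumPy sketch of one gru_unit forward step with the new attribute. The helper name and its weight layout (gate weights in the first two frame-wide column blocks, candidate weights in the last) are illustrative assumptions, not the test's actual set_outputs code:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_unit_step(x, h_p, w, origin_mode=False):
    """One GRU step: x is (batch, frame * 3) projected input,
    h_p is (batch, frame), w is (frame, frame * 3)."""
    frame = h_p.shape[1]
    gates = sigmoid(x[:, :2 * frame] + h_p.dot(w[:, :2 * frame]))
    u, r = gates[:, :frame], gates[:, frame:]  # update, reset gates
    c = np.tanh(x[:, 2 * frame:] + (r * h_p).dot(w[:, 2 * frame:]))
    if origin_mode:
        return u * h_p + (1 - u) * c  # original-paper formulation
    return u * c + (1 - u) * h_p      # default formulation

# usage
batch, frame = 2, 4
x = np.random.rand(batch, frame * 3)
h_p = np.random.rand(batch, frame)
w = np.random.rand(frame, frame * 3)
h = gru_unit_step(x, h_p, w, origin_mode=True)
```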