提交 816c18a2 编写于 作者: Q Qiao Longfei

update test_gru_unit_op

上级 72618c8d
...@@ -71,7 +71,7 @@ class TestGRUUnitOp(OpTest): ...@@ -71,7 +71,7 @@ class TestGRUUnitOp(OpTest):
'gate_activation': GRUActivationType.sigmoid 'gate_activation': GRUActivationType.sigmoid
} }
def set_outputs(self): def set_outputs(self, origin_mode=False):
# GRU calculations # GRU calculations
batch_size = self.batch_size batch_size = self.batch_size
frame_size = self.frame_size frame_size = self.frame_size
...@@ -93,7 +93,10 @@ class TestGRUUnitOp(OpTest): ...@@ -93,7 +93,10 @@ class TestGRUUnitOp(OpTest):
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) + c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:]) g[:, frame_size * 2:])
g = np.hstack((u_r, c)) g = np.hstack((u_r, c))
h = u * c + (1 - u) * h_p if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p
self.outputs = { self.outputs = {
'Gate': g.astype('float64'), 'Gate': g.astype('float64'),
'ResetHiddenPrev': r_h_p.astype('float64'), 'ResetHiddenPrev': r_h_p.astype('float64'),
...@@ -111,6 +114,12 @@ class TestGRUUnitOp(OpTest): ...@@ -111,6 +114,12 @@ class TestGRUUnitOp(OpTest):
self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden']) self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden'])
class TestGRUUnitOpOriginMode(TestGRUUnitOp):
def setUp(self):
self.set_inputs()
self.set_outputs(origin_mode=True)
class TestGRUUnitOpWithBias(TestGRUUnitOp): class TestGRUUnitOpWithBias(TestGRUUnitOp):
def set_inputs(self): def set_inputs(self):
batch_size = self.batch_size batch_size = self.batch_size
...@@ -132,5 +141,11 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): ...@@ -132,5 +141,11 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
no_grad_set=set('Input')) no_grad_set=set('Input'))
class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias):
def setUp(self):
self.set_inputs()
self.set_outputs(origin_mode=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册