提交 816c18a2 编写于 作者: Q Qiao Longfei

update test_gru_unit_op

上级 72618c8d
......@@ -71,7 +71,7 @@ class TestGRUUnitOp(OpTest):
'gate_activation': GRUActivationType.sigmoid
}
def set_outputs(self):
def set_outputs(self, origin_mode=False):
# GRU calculations
batch_size = self.batch_size
frame_size = self.frame_size
......@@ -93,6 +93,9 @@ class TestGRUUnitOp(OpTest):
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p
self.outputs = {
'Gate': g.astype('float64'),
......@@ -111,6 +114,12 @@ class TestGRUUnitOp(OpTest):
self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden'])
class TestGRUUnitOpOriginMode(TestGRUUnitOp):
def setUp(self):
self.set_inputs()
self.set_outputs(origin_mode=True)
class TestGRUUnitOpWithBias(TestGRUUnitOp):
def set_inputs(self):
batch_size = self.batch_size
......@@ -132,5 +141,11 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
no_grad_set=set('Input'))
class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias):
def setUp(self):
self.set_inputs()
self.set_outputs(origin_mode=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册