提交 9ecc54a1 编写于 作者: Y Yibing Liu

Remove redundant code in unit test

上级 76beff86
...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); ...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
......
...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); ...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
......
...@@ -131,7 +131,10 @@ def lstmp( ...@@ -131,7 +131,10 @@ def lstmp(
class TestLstmpOp(OpTest): class TestLstmpOp(OpTest):
def set_argument(self): def reset_argument(self):
pass
def setUp(self):
self.lod = [[0, 2, 5, 7]] self.lod = [[0, 2, 5, 7]]
# hidden size # hidden size
self.D = 16 self.D = 16
...@@ -147,8 +150,7 @@ class TestLstmpOp(OpTest): ...@@ -147,8 +150,7 @@ class TestLstmpOp(OpTest):
self.is_reverse = False self.is_reverse = False
self.use_peepholes = True self.use_peepholes = True
def setUp(self): self.reset_argument()
self.set_argument()
self.op_type = 'lstmp' self.op_type = 'lstmp'
T = self.lod[0][-1] T = self.lod[0][-1]
...@@ -212,19 +214,8 @@ class TestLstmpOp(OpTest): ...@@ -212,19 +214,8 @@ class TestLstmpOp(OpTest):
class TestLstmpOpHasInitial(TestLstmpOp): class TestLstmpOpHasInitial(TestLstmpOp):
def set_argument(self): def reset_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.P = 5
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.act_proj = self.act_cell
self.has_initial_state = True self.has_initial_state = True
self.is_reverse = True
self.use_peepholes = True
def test_check_grad(self): def test_check_grad(self):
# TODO(qingqing) remove folowing lines after the check_grad is refined. # TODO(qingqing) remove folowing lines after the check_grad is refined.
...@@ -313,52 +304,19 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -313,52 +304,19 @@ class TestLstmpOpHasInitial(TestLstmpOp):
class TestLstmpOpRerverse(TestLstmpOp): class TestLstmpOpRerverse(TestLstmpOp):
def set_argument(self): def reset_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.P = 10
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.act_proj = self.act_cell
self.has_initial_state = False
self.is_reverse = True self.is_reverse = True
self.use_peepholes = True
class TestLstmpOpNotUsePeepholes(TestLstmpOp): class TestLstmpOpNotUsePeepholes(TestLstmpOp):
def set_argument(self): def reset_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.P = 10
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.act_proj = self.act_cell
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = False self.use_peepholes = False
class TestLstmpOpLinearProjection(TestLstmpOp): class TestLstmpOpLinearProjection(TestLstmpOp):
def set_argument(self): def reset_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.P = 10
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.act_proj = 'identity' self.act_proj = 'identity'
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册