提交 6f02fe7d 编写于 作者: C chengduoZH

fix unit test

上级 b15c69f5
...@@ -130,8 +130,30 @@ class TestSeqProject(OpTest): ...@@ -130,8 +130,30 @@ class TestSeqProject(OpTest):
max_relative_error=0.05, max_relative_error=0.05,
no_grad_set=set(['X', 'PaddingData'])) no_grad_set=set(['X', 'PaddingData']))
def test_check_grad_input_filter(self):
self.check_grad(
['X', 'Filter'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['PaddingData']))
def test_check_grad_padding_input(self):
if self.padding_trainable:
self.check_grad(
['X', 'PaddingData'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['Filter']))
def test_check_grad_padding_filter(self):
if self.padding_trainable:
self.check_grad(
['PaddingData', 'Filter'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['X']))
def init_test_case(self): def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 11 self.input_row = 11
self.context_start = 0 self.context_start = 0
self.context_length = 1 self.context_length = 1
...@@ -144,7 +166,6 @@ class TestSeqProject(OpTest): ...@@ -144,7 +166,6 @@ class TestSeqProject(OpTest):
class TestSeqProjectCase1(TestSeqProject): class TestSeqProjectCase1(TestSeqProject):
def init_test_case(self): def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 11 self.input_row = 11
self.context_start = -1 self.context_start = -1
self.context_length = 3 self.context_length = 3
...@@ -157,7 +178,6 @@ class TestSeqProjectCase1(TestSeqProject): ...@@ -157,7 +178,6 @@ class TestSeqProjectCase1(TestSeqProject):
class TestSeqProjectCase2(TestSeqProject): class TestSeqProjectCase2(TestSeqProject):
def init_test_case(self): def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 25 self.input_row = 25
self.context_start = 2 self.context_start = 2
self.context_length = 3 self.context_length = 3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册