提交 fd24e195 编写于 作者: W wanghaoshuang

Uncomment check output in unitest

上级 89de5d5e
...@@ -178,21 +178,23 @@ class TestWarpCTCOp(OpTest): ...@@ -178,21 +178,23 @@ class TestWarpCTCOp(OpTest):
for i in range(batch_size): for i in range(batch_size):
max_sequence_length = max(max_sequence_length, max_sequence_length = max(max_sequence_length,
logits_lod[0][i + 1] - logits_lod[0][i]) logits_lod[0][i + 1] - logits_lod[0][i])
gradient = np.zeros( self.gradient = np.zeros(
[max_sequence_length, batch_size, num_classes], dtype="float32") [max_sequence_length, batch_size, num_classes], dtype="float32")
self.inputs = { self.inputs = {
"Logits": (logits, logits_lod), "Logits": (logits, logits_lod),
"Label": (labels, labels_lod) "Label": (labels, labels_lod)
} }
self.outputs = {"Loss": loss, "WarpCTCGrad": gradient} self.outputs = {"Loss": loss}
self.attrs = {"blank": blank, "norm_by_times": norm_by_times} self.attrs = {"blank": blank, "norm_by_times": norm_by_times}
# def test_check_output(self): def test_check_output(self):
# self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(["Logits"], "Loss", max_relative_error=0.01) self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册