Commit fd24e195 authored by W wanghaoshuang

Uncomment check_output in unittest

Parent 89de5d5e
@@ -178,21 +178,23 @@ class TestWarpCTCOp(OpTest):
         for i in range(batch_size):
             max_sequence_length = max(max_sequence_length,
                                       logits_lod[0][i + 1] - logits_lod[0][i])
-        gradient = np.zeros(
+        self.gradient = np.zeros(
             [max_sequence_length, batch_size, num_classes], dtype="float32")
 
         self.inputs = {
             "Logits": (logits, logits_lod),
             "Label": (labels, labels_lod)
         }
-        self.outputs = {"Loss": loss, "WarpCTCGrad": gradient}
+        self.outputs = {"Loss": loss}
         self.attrs = {"blank": blank, "norm_by_times": norm_by_times}
 
-    # def test_check_output(self):
-    #     self.check_output()
+    def test_check_output(self):
+        self.check_output()
 
     def test_check_grad(self):
+        self.outputs['WarpCTCGrad'] = self.gradient
         self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
 
 
 if __name__ == "__main__":
     unittest.main()
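For context, the pattern in this diff keeps the gradient tensor on the instance (self.gradient) rather than in self.outputs, so the re-enabled test_check_output only compares "Loss", and the "WarpCTCGrad" entry is registered just before the gradient check. Below is a minimal, framework-free sketch of that pattern; the FakeOpTest base class and its check_output/check_grad stand-ins are hypothetical and only mimic the shape of Paddle's OpTest API, they are not the real implementation.

import unittest
import numpy as np


class FakeOpTest(unittest.TestCase):
    """Hypothetical stand-in that mimics the shape of Paddle's OpTest API."""

    def check_output(self):
        # Compare every tensor currently registered in self.outputs.
        for name, expected in self.outputs.items():
            np.testing.assert_allclose(self.compute(name), expected, rtol=1e-5)

    def check_grad(self, inputs_to_check, output_name, max_relative_error=0.01):
        # A real implementation would compare numeric vs. analytic gradients;
        # here we only assert that the gradient tensor was registered first.
        self.assertIn("WarpCTCGrad", self.outputs)

    def compute(self, name):
        # Pretend the operator reproduces the expected output exactly.
        return self.outputs[name]


class TestPattern(FakeOpTest):
    def setUp(self):
        loss = np.zeros([2, 1], dtype="float32")
        # Keep the gradient on the instance, not in self.outputs, so that
        # check_output does not try to compare it.
        self.gradient = np.zeros([4, 2, 3], dtype="float32")
        self.outputs = {"Loss": loss}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Register the gradient only when the gradient check needs it.
        self.outputs["WarpCTCGrad"] = self.gradient
        self.check_grad(["Logits"], "Loss", max_relative_error=0.01)


if __name__ == "__main__":
    unittest.main()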