未验证 提交 5208e2d6 编写于 作者: C Ccc 提交者: GitHub

[AMP OP&Test] add unittest for log_softmax (#52264)

上级 c0697296
......@@ -102,6 +102,29 @@ class TestLogSoftmaxAxis(TestLogSoftmaxOp):
self.axis = 1
class TestLogSoftmaxFP16OP(TestLogSoftmaxOp):
def set_attrs(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-3)
def test_check_grad(self):
self.check_grad(['X'], ['Out'], max_relative_error=1e-2)
class TestLogSoftmaxShapeFP16OP(TestLogSoftmaxFP16OP):
def set_attrs(self):
self.dtype = np.float16
self.shape = [12, 10]
class TestLogSoftmaxAxisFP16OP(TestLogSoftmaxFP16OP):
def set_attrs(self):
self.dtype = np.float16
self.axis = 1
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册