未验证 提交 0155f916 编写于 作者: Z zhupengyang 提交者: GitHub

enable softmax unittest (#28362)

上级 c42e6561
......@@ -267,22 +267,11 @@ class TestSoftmaxFP16Op(TestSoftmaxOp):
pass
@unittest.skip('disable TestSoftmaxFP16Op2')
class TestSoftmaxFP16Op2(TestSoftmaxOp):
    # Float16 variant of the softmax op test. The whole class is disabled
    # by the @unittest.skip decorator above (this is the pre-change version
    # that commit #28362 removes).

    def init_kernel_type(self):
        # Run the op's kernel in half precision.
        self.dtype = np.float16

    def test_check_output(self):
        # FP16 kernels only exist on CUDA builds, and only on devices that
        # actually support float16 — silently no-op otherwise.
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                # Looser tolerance (atol=1e-3), presumably to absorb
                # half-precision rounding — TODO confirm against OpTest docs.
                self.check_output_with_place(place, atol=1e-3)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxFP16Op2(TestSoftmaxFP16Op):
def get_x_shape(self):
return [2, 3, 4, 5]
def test_check_grad(self):
pass
return [2, 3, 4, 10]
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -354,10 +343,12 @@ class TestSoftmaxAPI(unittest.TestCase):
# The input type must be Variable.
self.assertRaises(TypeError, F.softmax, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[2, 3], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[2, 3], dtype='int32')
self.assertRaises(TypeError, F.softmax, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[2, 3], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[2, 3], dtype='float16')
F.softmax(x_fp16)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册