未验证 提交 95b35f77 编写于 作者: T Tao Luo 提交者: GitHub

fix softmax input error check on float16 (#20273) (#20346)

test=release/1.6
上级 6e4a454e
...@@ -2380,9 +2380,13 @@ def softmax(input, use_cudnn=False, name=None, axis=-1): ...@@ -2380,9 +2380,13 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
raise TypeError( raise TypeError(
"The type of 'input' in softmax must be Variable, but received %s" % "The type of 'input' in softmax must be Variable, but received %s" %
(type(input))) (type(input)))
if convert_dtype(input.dtype) not in ['float32', 'float64']: if convert_dtype(input.dtype) in ['float16']:
warnings.warn(
"The data type of 'input' in softmax only support float16 in GPU now."
)
if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError( raise TypeError(
"The data type of 'input' in softmax must be float32 or float64, but received %s." "The data type of 'input' in softmax must be float16, float32 or float64, but received %s."
% (convert_dtype(input.dtype))) % (convert_dtype(input.dtype)))
dtype = helper.input_dtype() dtype = helper.input_dtype()
......
...@@ -83,9 +83,11 @@ class TestSoftmaxOpError(OpTest): ...@@ -83,9 +83,11 @@ class TestSoftmaxOpError(OpTest):
x1 = fluid.create_lod_tensor( x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace()) np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.softmax, x1) self.assertRaises(TypeError, fluid.layers.softmax, x1)
# The input dtype of softmax_op must be float32 or float64. # The input dtype of softmax_op must be float16, float32 or float64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32") x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.softmax, x2) self.assertRaises(TypeError, fluid.layers.softmax, x2)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16")
fluid.layers.softmax(x3)
class TestSoftmaxOp2(TestSoftmaxOp): class TestSoftmaxOp2(TestSoftmaxOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册