Unverified commit 134d809e authored by Tao Luo, committed by GitHub

fix softmax input error check on float16 (#20273)

test=develop
Parent 54df1ad4
@@ -2380,9 +2380,13 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
         raise TypeError(
             "The type of 'input' in softmax must be Variable, but received %s" %
             (type(input)))
-    if convert_dtype(input.dtype) not in ['float32', 'float64']:
+    if convert_dtype(input.dtype) in ['float16']:
+        warnings.warn(
+            "The data type of 'input' in softmax only support float16 in GPU now."
+        )
+    if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
         raise TypeError(
-            "The data type of 'input' in softmax must be float32 or float64, but received %s."
+            "The data type of 'input' in softmax must be float16, float32 or float64, but received %s."
             % (convert_dtype(input.dtype)))
     dtype = helper.input_dtype()

@@ -83,9 +83,11 @@ class TestSoftmaxOpError(OpTest):
             x1 = fluid.create_lod_tensor(
                 np.array([[-1]]), [[1]], fluid.CPUPlace())
             self.assertRaises(TypeError, fluid.layers.softmax, x1)
-            # The input dtype of softmax_op must be float32 or float64.
+            # The input dtype of softmax_op must be float16, float32 or float64.
             x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
             self.assertRaises(TypeError, fluid.layers.softmax, x2)
+            x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16")
+            fluid.layers.softmax(x3)


 class TestSoftmaxOp2(TestSoftmaxOp):
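For context, a minimal sketch of the new behavior from the caller's side, using the same fluid.layers API as the test above (the variable names x_fp16 and x_int32 are illustrative, not part of the patch): a float16 input now only emits a warning that float16 softmax is supported on GPU, while an unsupported dtype such as int32 still raises TypeError.

import paddle.fluid as fluid

# float16 input: accepted by the layer-level check, with a warning that
# float16 softmax is only supported on GPU.
x_fp16 = fluid.layers.data(name='x_fp16', shape=[4], dtype='float16')
out = fluid.layers.softmax(x_fp16)

# Unsupported dtype (e.g. int32): still rejected with a TypeError at graph-build time.
x_int32 = fluid.layers.data(name='x_int32', shape=[4], dtype='int32')
try:
    fluid.layers.softmax(x_int32)
except TypeError as e:
    print(e)  # "... must be float16, float32 or float64, but received int32."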