未验证 提交 59841444 编写于 作者: Z Zhang Ting 提交者: GitHub

fix dtype checking for softmax (#51929)

上级 2b98993b
......@@ -1110,15 +1110,15 @@ def softmax(x, axis=-1, dtype=None, name=None):
use_cudnn = True
if dtype is None:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'softmax'
x, 'x', ['float16', 'bfloat16', 'float32', 'float64'], 'softmax'
)
else:
check_dtype(
dtype,
'dtype',
['float32', 'float64'],
['float16', 'bfloat16', 'float32', 'float64'],
'softmax',
'If dtype is not None, it only support float32 or float64.',
'If dtype is not None, it only support float16, bfloat16, float32 or float64.',
)
helper = LayerHelper("softmax", **locals())
......
......@@ -1324,7 +1324,7 @@ class Softmax(Layer):
self._name = name
def forward(self, x):
    """Apply softmax to ``x`` along the configured axis.

    Args:
        x: Input tensor to normalize.

    Returns:
        The softmax of ``x`` computed over ``self._axis``.
    """
    # NOTE(review): the scraped diff left both the pre- and post-change
    # return lines in place; the pre-change line passed self._dtype as a
    # positional argument, which this commit (#51929) removes. Only the
    # post-change call is kept here. dtype is intentionally no longer
    # forwarded — F.softmax now validates dtype itself when given.
    return F.softmax(x, self._axis, name=self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册