From 598414446116a9490739a1ba09c7fbd72a441656 Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Wed, 22 Mar 2023 14:06:22 +0800
Subject: [PATCH] fix dtype checking for softmax (#51929)

---
 python/paddle/nn/functional/activation.py | 6 +++---
 python/paddle/nn/layer/activation.py      | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 0ec80b3c5bb..4dc3e3a62ec 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -1110,15 +1110,15 @@ def softmax(x, axis=-1, dtype=None, name=None):
         use_cudnn = True
         if dtype is None:
             check_variable_and_dtype(
-                x, 'x', ['float16', 'float32', 'float64'], 'softmax'
+                x, 'x', ['float16', 'bfloat16', 'float32', 'float64'], 'softmax'
             )
         else:
             check_dtype(
                 dtype,
                 'dtype',
-                ['float32', 'float64'],
+                ['float16', 'bfloat16', 'float32', 'float64'],
                 'softmax',
-                'If dtype is not None, it only support float32 or float64.',
+                'If dtype is not None, it only supports float16, bfloat16, float32 or float64.',
             )

         helper = LayerHelper("softmax", **locals())
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index ebe305d7869..6c85ae646b7 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -1324,7 +1324,7 @@ class Softmax(Layer):
         self._name = name

     def forward(self, x):
-        return F.softmax(x, self._axis, self._dtype, self._name)
+        return F.softmax(x, self._axis, name=self._name)

     def extra_repr(self):
         name_str = ', name={}'.format(self._name) if self._name else ''
-- 
GitLab
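
A minimal sketch of what the widened check allows, assuming a Paddle build that carries this commit. The check_variable_and_dtype/check_dtype calls touched by the first hunk only run on the static-graph path, so the example exercises that path; before the patch, a float16 dtype argument was rejected with "If dtype is not None, it only support float32 or float64."

import paddle
import paddle.nn.functional as F

# The dtype checks changed in this patch fire during static-graph
# construction, so switch to static mode explicitly.
paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    # Before this patch, check_dtype accepted only float32/float64 here
    # and raised for dtype='float16'; bfloat16 inputs likewise failed
    # check_variable_and_dtype. Both now pass.
    out = F.softmax(x, axis=-1, dtype='float16')

The second hunk is related cleanup: paddle.nn.Softmax.forward stops forwarding self._dtype positionally, so the layer calls F.softmax with dtype=None and the softmax runs in the input's own dtype.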