未验证 提交 d832a54d 编写于 作者: H haozi 提交者: GitHub

fix fp16 dtype checking for clip op (#50878)

* fix fp16 dtype checking for clip op

* modify the name

* fix type error

* fix check error

* Update test_clip_op.py

fix test error

* Update test_clip_op.py

fix code style

---------
Co-authored-by: NZhang Ting <Douyaer2020@qq.com>
上级 6b85eb59
...@@ -299,6 +299,33 @@ class TestClipAPI(unittest.TestCase): ...@@ -299,6 +299,33 @@ class TestClipAPI(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
class TestClipOpFp16(unittest.TestCase):
def test_fp16(self):
paddle.enable_static()
data_shape = [1, 9, 9, 4]
data = np.random.random(data_shape).astype('float16')
with paddle.static.program_guard(paddle.static.Program()):
images = paddle.static.data(
name='image1', shape=data_shape, dtype='float16'
)
min = paddle.static.data(name='min1', shape=[1], dtype='float16')
max = paddle.static.data(name='max1', shape=[1], dtype='float16')
out = paddle.clip(images, min, max)
if fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
res1 = exe.run(
feed={
"image1": data,
"min1": np.array([0.2]).astype('float16'),
"max1": np.array([0.8]).astype('float16'),
},
fetch_list=[out],
)
paddle.disable_static()
class TestInplaceClipAPI(TestClipAPI): class TestInplaceClipAPI(TestClipAPI):
def _executed_api(self, x, min=None, max=None): def _executed_api(self, x, min=None, max=None):
return x.clip_(min, max) return x.clip_(min, max)
......
...@@ -2770,11 +2770,11 @@ def clip(x, min=None, max=None, name=None): ...@@ -2770,11 +2770,11 @@ def clip(x, min=None, max=None, name=None):
Out = MIN(MAX(x, min), max) Out = MIN(MAX(x, min), max)
Args: Args:
x (Tensor): An N-D Tensor with data type float32, float64, int32 or int64. x (Tensor): An N-D Tensor with data type float16, float32, float64, int32 or int64.
min (float|int|Tensor, optional): The lower bound with type ``float`` , ``int`` or a ``Tensor`` min (float|int|Tensor, optional): The lower bound with type ``float`` , ``int`` or a ``Tensor``
with shape [1] and type ``int32``, ``float32``, ``float64``. with shape [1] and type ``int32``, ``float16``, ``float32``, ``float64``.
max (float|int|Tensor, optional): The upper bound with type ``float``, ``int`` or a ``Tensor`` max (float|int|Tensor, optional): The upper bound with type ``float``, ``int`` or a ``Tensor``
with shape [1] and type ``int32``, ``float32``, ``float64``. with shape [1] and type ``int32``, ``float16``, ``float32``, ``float64``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -2803,6 +2803,9 @@ def clip(x, min=None, max=None, name=None): ...@@ -2803,6 +2803,9 @@ def clip(x, min=None, max=None, name=None):
elif x_dtype == 'paddle.int64': elif x_dtype == 'paddle.int64':
min_ = np.iinfo(np.int64).min min_ = np.iinfo(np.int64).min
max_ = np.iinfo(np.int64).max - 2**39 max_ = np.iinfo(np.int64).max - 2**39
elif x_dtype == 'paddle.float16':
min_ = float(np.finfo(np.float16).min)
max_ = float(np.finfo(np.float16).max)
else: else:
min_ = float(np.finfo(np.float32).min) min_ = float(np.finfo(np.float32).min)
max_ = float(np.finfo(np.float32).max) max_ = float(np.finfo(np.float32).max)
...@@ -2822,7 +2825,7 @@ def clip(x, min=None, max=None, name=None): ...@@ -2822,7 +2825,7 @@ def clip(x, min=None, max=None, name=None):
check_dtype( check_dtype(
min.dtype, min.dtype,
'min', 'min',
['float32', 'float64', 'int32'], ['float16', 'float32', 'float64', 'int32'],
'clip', 'clip',
'(When the type of min in clip is Variable.)', '(When the type of min in clip is Variable.)',
) )
...@@ -2832,13 +2835,13 @@ def clip(x, min=None, max=None, name=None): ...@@ -2832,13 +2835,13 @@ def clip(x, min=None, max=None, name=None):
check_dtype( check_dtype(
max.dtype, max.dtype,
'max', 'max',
['float32', 'float64', 'int32'], ['float16', 'float32', 'float64', 'int32'],
'clip', 'clip',
'(When the type of max in clip is Variable.)', '(When the type of max in clip is Variable.)',
) )
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'clip' x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'clip'
) )
inputs = {'X': x} inputs = {'X': x}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册