未验证 提交 1d0535a2 编写于 作者: Y yikaikkk 提交者: GitHub

add float16 to log log1p logsumexp (#51216)

* add float16 to log log1p logsumexp

* update log

* update test_activation_op.py

* update

* update codestyle

* Update  test_activation_op codestyle

* update

* update

* up

* update codestyle

* emmm

* -

---------
Co-authored-by: Nwqgo <1552367872@qq.com>
上级 74cd3889
......@@ -24,6 +24,7 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn.functional as F
import paddle.static as static
from paddle.fluid import Program, program_guard
from paddle.fluid.layer_helper import LayerHelper
......@@ -2688,6 +2689,21 @@ class TestLog(TestActivation):
self.assertRaises(TypeError, paddle.log, in2)
class Test_Log_Op_Fp16(unittest.TestCase):
def test_api_fp16(self):
paddle.enable_static()
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = [[2, 3, 4], [7, 8, 9]]
x = paddle.to_tensor(x, dtype='float16')
out = paddle.log(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
(res,) = exe.run(fetch_list=[out])
class TestLog_ZeroDim(TestLog):
def init_shape(self):
self.shape = []
......@@ -2838,6 +2854,21 @@ class TestLog1p(TestActivation):
self.check_grad(['X'], 'Out', check_eager=True)
class Test_Log1p_Op_Fp16(unittest.TestCase):
def test_api_fp16(self):
paddle.enable_static()
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = [[2, 3, 4], [7, 8, 9]]
x = paddle.to_tensor(x, dtype='float16')
out = paddle.log1p(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
(res,) = exe.run(fetch_list=[out])
class TestLog1p_ZeroDim(TestLog1p):
def init_shape(self):
self.shape = []
......
......@@ -135,7 +135,7 @@ def log(x, name=None):
Out = \ln(x)
Args:
x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64.
name (str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
......@@ -156,7 +156,9 @@ def log(x, name=None):
if in_dygraph_mode():
return _C_ops.log(x)
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], "log")
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], "log"
)
inputs = {'X': [x]}
helper = LayerHelper('log', **locals())
dtype = helper.input_dtype(input_param_name='x')
......@@ -2127,7 +2129,7 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
logsumexp(x) = \log\sum exp(x)
Args:
x (Tensor): The input Tensor with data type float32 or float64, which
x (Tensor): The input Tensor with data type float16, float32 or float64, which
have no more than 4 dimensions.
axis (int|list|tuple, optional): The axis along which to perform
logsumexp calculations. ``axis`` should be int, list(int) or
......@@ -2166,7 +2168,9 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
if in_dygraph_mode():
return _C_ops.logsumexp(x, axis, keepdim, reduce_all)
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'logsumexp')
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'logsumexp'
)
helper = LayerHelper('logsumexp', **locals())
attrs = {'axis': axis, 'keepdim': keepdim, 'reduce_all': reduce_all}
......@@ -2648,7 +2652,7 @@ def log1p(x, name=None):
Out = \ln(x+1)
Args:
x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -2667,7 +2671,9 @@ def log1p(x, name=None):
if in_dygraph_mode():
return _C_ops.log1p(x)
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], "log1p")
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], "log1p"
)
inputs = {'X': [x]}
helper = LayerHelper('log1p', **locals())
dtype = helper.input_dtype(input_param_name='x')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册