diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
old mode 100644
new mode 100755
index 26b4ed71e00219fcb5f5942a69d11e983f245e89..8776644b91424213678bce1fa94886d0d64db91a
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -249,6 +249,15 @@ $$out = cos(x)$$
 
 )DOC";
 
+UNUSED constexpr char TanDoc[] = R"DOC(
+Tangent Operator. Computes tangent of x element-wise.
+
+Input range is `(k*pi-pi/2, k*pi+pi/2)` and output range is `(-inf, inf)`.
+
+$$out = tan(x)$$
+
+)DOC";
+
 UNUSED constexpr char SinDoc[] = R"DOC(
 Sine Activation Operator.
 
@@ -709,6 +718,7 @@ REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
 REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
 REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
 REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
+REGISTER_ACTIVATION_OP_MAKER(Tan, TanDoc);
 REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
 REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc);
 REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 43907744f956a62b1c21a95559104145a59060cd..3a8bf17f079fdaa22f0dffa5df23a19abfbebfb2 100755
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -584,6 +584,39 @@ struct SinFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct Tangent {
+  HOSTDEVICE T operator()(const T& val) const { return tan(val); }
+};
+
+template <>
+struct Tangent<platform::float16> {
+  HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
+    return platform::float16(tan(static_cast<float>(val)));
+  }
+};
+
+// Tangent'(x) = 1 / Cosine(x)^2
+template <typename T>
+struct TanGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout / x.unaryExpr(Cosine<T>()).square();
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+// Tangent(x) = tan(x)
+template <typename T>
+struct TanFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.unaryExpr(Tangent<T>());
+  }
+};
+
 template <typename T>
 struct Sinh {
   HOSTDEVICE T operator()(const T& val) const { return sinh(val); }
@@ -1942,6 +1975,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
   __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor);                 \
   __macro(floor, Floor, FloorFunctor, ZeroGradFunctor);              \
   __macro(cos, Cos, CosFunctor, CosGradFunctor);                     \
+  __macro(tan, Tan, TanFunctor, TanGradFunctor);                     \
   __macro(acos, Acos, AcosFunctor, AcosGradFunctor);                 \
   __macro(sin, Sin, SinFunctor, SinGradFunctor);                     \
   __macro(asin, Asin, AsinFunctor, AsinGradFunctor);                 \
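TanGradFunctor above encodes d tan(x)/dx = 1 / cos(x)^2 (i.e. sec^2(x)), which is what `dout / x.unaryExpr(Cosine<T>()).square()` computes. A minimal numerical sanity sketch of that formula, plain NumPy only and not part of the patch itself:

    # Sanity sketch: check that d tan(x)/dx = 1 / cos(x)^2, the formula
    # implemented by TanGradFunctor, against a central finite difference.
    import numpy as np

    x = np.linspace(-1.0, 1.0, 7)  # points inside (-pi/2, pi/2)
    eps = 1e-6

    analytic = 1.0 / np.cos(x) ** 2                            # what the functor computes
    numeric = (np.tan(x + eps) - np.tan(x - eps)) / (2 * eps)  # central difference

    assert np.allclose(analytic, numeric, rtol=1e-4)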
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 908e06b96e493c825537e81c09caf992bb2a4608..ac279b796e486c91c6ed10c1e76c765e6c2b2e1d 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -134,6 +134,7 @@ from .tensor.math import asin  #DEFINE_ALIAS
 from .tensor.math import atan  #DEFINE_ALIAS
 from .tensor.math import ceil  #DEFINE_ALIAS
 from .tensor.math import cos  #DEFINE_ALIAS
+from .tensor.math import tan  #DEFINE_ALIAS
 from .tensor.math import cosh  #DEFINE_ALIAS
 from .tensor.math import cumsum  #DEFINE_ALIAS
 # from .tensor.math import elementwise_add  #DEFINE_ALIAS
diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py
old mode 100644
new mode 100755
index 4a429a94e1ec639dbdafb1427f425157763eb91a..841daf7a41d1fa6a84feba3632328c68723fba3f
--- a/python/paddle/fluid/layers/ops.py
+++ b/python/paddle/fluid/layers/ops.py
@@ -43,6 +43,7 @@ __unary_func__ = [
     'ceil',
     'floor',
     'cos',
+    'tan',
     'acos',
     'sin',
     'sinh',
@@ -244,6 +245,19 @@ Examples:
 
 """)
 
+add_sample_code(globals()["tan"], r"""
+Examples:
+    .. code-block:: python
+
+        import paddle
+
+        x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
+        out = paddle.tan(x)
+        print(out)
+        # [-0.42279324, -0.20271005, 0.10033467, 0.30933627]
+
+""")
+
 add_sample_code(globals()["acos"], r"""
 Examples:
     .. code-block:: python
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index f0bb15ae93bb2e6e008f8d47987e1027c9285511..a9982dc132970463454527c89127745a4a443a00 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -13,16 +13,17 @@
 # limitations under the License.
 
 from __future__ import print_function
-
 import unittest
+
 import numpy as np
-import paddle.fluid.core as core
-from op_test import OpTest
 from scipy.special import expit, erf
+
+from op_test import OpTest
 import paddle
-import paddle.fluid as fluid
 import paddle.nn as nn
 import paddle.nn.functional as F
+import paddle.fluid as fluid
+import paddle.fluid.core as core
 from paddle.fluid import compiler, Program, program_guard
 
 paddle.enable_static()
@@ -137,7 +138,7 @@ class TestLogSigmoidAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
             else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -218,7 +219,7 @@ class TestTanhAPI(unittest.TestCase):
         self.dtype = 'float32'
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
-        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
             else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -480,7 +481,7 @@ class TestTanhshrinkAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(10, 20, [10, 17]).astype(np.float64)
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
             else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -572,7 +573,7 @@ class TestHardShrinkAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
             else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -644,7 +645,7 @@ class TestHardtanhAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
             else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -726,7 +727,7 @@ class TestSoftshrinkAPI(unittest.TestCase):
         self.threshold = 0.8
         np.random.seed(1024)
         self.x_np = np.random.uniform(0.25, 10, [10, 12]).astype(np.float64)
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
             else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -895,6 +896,57 @@ class TestCos(TestActivation):
         self.check_grad(['X'], 'Out')
 
 
+class TestTan(TestActivation):
+    def setUp(self):
+        np.random.seed(1024)
+        self.op_type = "tan"
+        self.init_dtype()
+        self.dtype = 'float32'
+        self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
+        self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
+            else paddle.CPUPlace()
+
+        out = np.tan(self.x_np)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x_np)}
+        self.outputs = {'Out': out}
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad(['X'], 'Out')
+
+    def test_dygraph_api(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x_np)
+        out_test = paddle.tan(x)
+        out_ref = np.tan(self.x_np)
+        self.assertTrue(np.allclose(out_ref, out_test.numpy()))
+        paddle.enable_static()
+
+    def test_static_api(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data('X', [10, 12], self.dtype)
+            out = paddle.tan(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
+        out_ref = np.tan(self.x_np)
+        self.assertTrue(np.allclose(out_ref, res[0]))
+
+    def test_backward(self):
+        test_data_shape = [11, 17]
+        with fluid.dygraph.guard():
+            input_x = np.random.uniform(0.1, 1,
+                                        test_data_shape).astype("float32")
+            var = paddle.to_tensor(input_x)
+            var.stop_gradient = False
+            loss = paddle.tan(var)
+            loss.backward()
+            grad_var = var.gradient()
+            self.assertEqual(grad_var.shape, input_x.shape)
+
+
 class TestAcos(TestActivation):
     def setUp(self):
         self.op_type = "acos"
@@ -990,7 +1042,7 @@ class TestReluAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -1084,7 +1136,7 @@ class TestLeakyReluAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -1195,7 +1247,7 @@ class TestGELUAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -1281,7 +1333,7 @@ class TestBreluAPI(unittest.TestCase):
         self.out_ref[self.out_ref < self.t_min] = self.t_min
         self.out_ref[self.out_ref > self.t_max] = self.t_max
         self.out_ref = self.out_ref.astype('float32')
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_fluid_api(self):
@@ -1344,7 +1396,7 @@ class TestRelu6API(unittest.TestCase):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 10, [10, 12]).astype(np.float64)
         self.x_np[np.abs(self.x_np) < 0.005] = 0.02
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -1430,7 +1482,7 @@ class TestHardswishAPI(unittest.TestCase):
     # test paddle.nn.Hardswish, paddle.nn.functional.hardswish
     def setUp(self):
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -1555,7 +1607,7 @@ class TestELUAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -2055,7 +2107,7 @@ class TestSoftplusAPI(unittest.TestCase):
         self.threshold = 15
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -2134,7 +2186,7 @@ class TestSoftsignAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -2219,7 +2271,7 @@ class TestThresholdedReluAPI(unittest.TestCase):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-20, 20, [10, 12]).astype(np.float64)
         self.x_np[np.abs(self.x_np) < 0.005] = 0.02
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
@@ -2317,12 +2369,12 @@ class TestHardsigmoidAPI(unittest.TestCase):
     # test paddle.nn.Hardsigmoid, paddle.nn.functional.hardsigmoid
     def setUp(self):
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
         with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.fluid.data('X', self.x_np.shape, self.x_np.dtype)
+            x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
             out1 = F.hardsigmoid(x)
             m = paddle.nn.Hardsigmoid()
             out2 = m(x)
@@ -2400,13 +2452,13 @@ class TestSwishAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(1024)
         self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
-        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+        self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
 
     def test_static_api(self):
         paddle.enable_static()
         with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.fluid.data('X', self.x_np.shape, self.x_np.dtype)
+            x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
             out1 = F.swish(x)
             swish = paddle.nn.Swish()
             out2 = swish(x)
@@ -2483,6 +2535,7 @@ create_test_error_class('rsqrt')
 create_test_error_class('sin')
 create_test_error_class('sqrt')
 create_test_error_class('tanh')
+create_test_error_class('tan')
 
 
 #------------------ Test Cudnn Activation----------------------
@@ -2509,7 +2562,7 @@ def create_test_act_fp16_class(parent,
                                atol=1e-3,
                                grad_check=True,
                                grad_atol=0.80):
-    @unittest.skipIf(not core.is_compiled_with_cuda(),
+    @unittest.skipIf(not paddle.is_compiled_with_cuda(),
                      "core is not compiled with CUDA")
     class TestActFp16(parent):
         def init_dtype(self):
@@ -2545,6 +2598,7 @@ create_test_act_fp16_class(TestAbs)
 create_test_act_fp16_class(TestCeil, grad_check=False)
 create_test_act_fp16_class(TestFloor, grad_check=False)
 create_test_act_fp16_class(TestCos, grad_atol=0.85)
+create_test_act_fp16_class(TestTan, grad_atol=0.85)
 create_test_act_fp16_class(TestCosh, grad_atol=0.85)
 create_test_act_fp16_class(TestAcos, grad_atol=0.85)
 create_test_act_fp16_class(TestSin)
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 515b4024471209cb84fdf354d9167ee07aa259f6..daee64b420453aa46e8274783a456a7f7b702690 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -104,6 +104,7 @@ from .math import asin  #DEFINE_ALIAS
 from .math import atan  #DEFINE_ALIAS
 from .math import ceil  #DEFINE_ALIAS
 from .math import cos  #DEFINE_ALIAS
+from .math import tan  #DEFINE_ALIAS
 from .math import cosh  #DEFINE_ALIAS
 from .math import cumsum  #DEFINE_ALIAS
 # from .math import elementwise_add  #DEFINE_ALIAS
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 80d2a4a513398ed1630014a0327efbc1d0010fe9..3d3d24c7c254b69eb1a8c5c99d22fe00087216e3 100755
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -33,6 +33,7 @@ from ..fluid.layers import acos  #DEFINE_ALIAS
 from ..fluid.layers import asin  #DEFINE_ALIAS
 from ..fluid.layers import ceil  #DEFINE_ALIAS
 from ..fluid.layers import cos  #DEFINE_ALIAS
+from ..fluid.layers import tan  #DEFINE_ALIAS
 from ..fluid.layers import sinh  #DEFINE_ALIAS
 from ..fluid.layers import cosh  #DEFINE_ALIAS
 # from ..fluid.layers import elementwise_add  #DEFINE_ALIAS
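With the patch applied, the new paddle.tan alias is importable from both paddle and paddle.tensor. A minimal usage sketch, assuming a Paddle build that includes this change; it mirrors the docstring sample added to ops.py and TestTan.test_static_api above:

    import numpy as np
    import paddle

    # Dygraph: same values as the docstring sample added to ops.py.
    x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
    print(paddle.tan(x).numpy())  # close to np.tan([-0.4, -0.2, 0.1, 0.3])

    # Static graph: same flow as TestTan.test_static_api.
    paddle.enable_static()
    x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
    with paddle.static.program_guard(paddle.static.Program()):
        x = paddle.static.data('X', [10, 12], 'float32')
        out = paddle.tan(x)
        exe = paddle.static.Executor(paddle.CPUPlace())
        res = exe.run(feed={'X': x_np}, fetch_list=[out])
    assert np.allclose(res[0], np.tan(x_np))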