Unverified commit 87e75a77, authored by joejiong, committed by GitHub

Add tangent operator (#29207)

As the title
Parent 95e33481
@@ -249,6 +249,15 @@ $$out = cos(x)$$
)DOC";
UNUSED constexpr char TanDoc[] = R"DOC(
Tangent Operator. Computes tangent of x element-wise.
Input range is `(k*pi-pi/2, k*pi+pi/2)` and output range is `(-inf, inf)`.
$$out = tan(x)$$
)DOC";
UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.
@@ -709,6 +718,7 @@ REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Tan, TanDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
......
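The TanDoc block above documents the principal branches: tan is finite on each interval (k*pi - pi/2, k*pi + pi/2) and unbounded toward the endpoints. A quick numpy illustration of that behavior (not part of the commit):

import numpy as np

x = np.array([-1.5, -0.5, 0.0, 0.5, 1.5])
print(np.tan(x))          # finite inside a branch, large near +-pi/2
print(np.tan(np.pi / 2))  # ~1.6e16 in float64: the asymptote never hits inf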
@@ -584,6 +584,39 @@ struct SinFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct Tangent {
HOSTDEVICE T operator()(const T& val) const { return tan(val); }
};
template <>
struct Tangent<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(tan(static_cast<float>(val)));
}
};
// Tangent'(x) = 1 / Cosine(x)^2
template <typename T>
struct TanGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout / x.unaryExpr(Cosine<T>()).square();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Tangent(x) = tan(x)
template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Tangent<T>());
}
};
template <typename T>
struct Sinh {
HOSTDEVICE T operator()(const T& val) const { return sinh(val); }
@@ -1942,6 +1975,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(cos, Cos, CosFunctor, CosGradFunctor); \
__macro(tan, Tan, TanFunctor, TanGradFunctor); \
__macro(acos, Acos, AcosFunctor, AcosGradFunctor); \
__macro(sin, Sin, SinFunctor, SinGradFunctor); \
__macro(asin, Asin, AsinFunctor, AsinGradFunctor); \
......
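TanGradFunctor implements d/dx tan(x) = 1 / cos(x)^2, dividing the upstream gradient by cos(x) squared. A short finite-difference sketch (not from the commit) that checks the formula numerically:

import numpy as np

x = np.linspace(-1.0, 1.0, 9)
eps = 1e-6
numeric = (np.tan(x + eps) - np.tan(x - eps)) / (2 * eps)  # central difference
analytic = 1.0 / np.cos(x) ** 2                            # the functor's formula
assert np.allclose(numeric, analytic, atol=1e-4)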
@@ -134,6 +134,7 @@ from .tensor.math import asin #DEFINE_ALIAS
from .tensor.math import atan #DEFINE_ALIAS
from .tensor.math import ceil #DEFINE_ALIAS
from .tensor.math import cos #DEFINE_ALIAS
from .tensor.math import tan #DEFINE_ALIAS
from .tensor.math import cosh #DEFINE_ALIAS
from .tensor.math import cumsum #DEFINE_ALIAS
# from .tensor.math import elementwise_add #DEFINE_ALIAS
......
@@ -43,6 +43,7 @@ __unary_func__ = [
'ceil',
'floor',
'cos',
'tan',
'acos',
'sin',
'sinh',
@@ -244,6 +245,19 @@ Examples:
""")
add_sample_code(globals()["tan"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.tan(x)
print(out)
# [-0.42279324, -0.20271005, 0.10033467, 0.30933627]
""")
add_sample_code(globals()["acos"], r""" add_sample_code(globals()["acos"], r"""
Examples: Examples:
.. code-block:: python .. code-block:: python
......
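The sample code added for tan runs in dygraph mode; a static-graph equivalent, sketched along the lines of the TestTan.test_static_api case further down (shapes and input values here are illustrative), would look roughly like this:

import numpy as np
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data('X', [4], 'float32')
    out = paddle.tan(x)
    exe = paddle.static.Executor(paddle.CPUPlace())
    res = exe.run(feed={'X': np.array([-0.4, -0.2, 0.1, 0.3], 'float32')},
                  fetch_list=[out])
print(res[0])  # same values as the dygraph sample above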
@@ -13,16 +13,17 @@
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from scipy.special import expit, erf
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import compiler, Program, program_guard
paddle.enable_static()
@@ -137,7 +138,7 @@ class TestLogSigmoidAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -218,7 +219,7 @@ class TestTanhAPI(unittest.TestCase):
self.dtype = 'float32'
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -480,7 +481,7 @@ class TestTanhshrinkAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(10, 20, [10, 17]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -572,7 +573,7 @@ class TestHardShrinkAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -644,7 +645,7 @@ class TestHardtanhAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -726,7 +727,7 @@ class TestSoftshrinkAPI(unittest.TestCase):
self.threshold = 0.8
np.random.seed(1024)
self.x_np = np.random.uniform(0.25, 10, [10, 12]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -895,6 +896,57 @@ class TestCos(TestActivation):
self.check_grad(['X'], 'Out')
class TestTan(TestActivation):
def setUp(self):
np.random.seed(1024)
self.op_type = "tan"
self.init_dtype()
self.dtype = 'float32'
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
out = np.tan(self.x_np)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x_np)}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out_test = paddle.tan(x)
out_ref = np.tan(self.x_np)
self.assertTrue(np.allclose(out_ref, out_test.numpy()))
paddle.enable_static()
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12], self.dtype)
out = paddle.tan(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = np.tan(self.x_np)
self.assertTrue(np.allclose(out_ref, res[0]))
def test_backward(self):
test_data_shape = [11, 17]
with fluid.dygraph.guard():
input_x = np.random.uniform(0.1, 1,
test_data_shape).astype("float32")
var = paddle.to_tensor(input_x)
var.stop_gradient = False
loss = paddle.tan(var)
loss.backward()
grad_var = var.gradient()
self.assertEqual(grad_var.shape, input_x.shape)
class TestAcos(TestActivation):
def setUp(self):
self.op_type = "acos"
@@ -990,7 +1042,7 @@ class TestReluAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -1084,7 +1136,7 @@ class TestLeakyReluAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -1195,7 +1247,7 @@ class TestGELUAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -1281,7 +1333,7 @@ class TestBreluAPI(unittest.TestCase):
self.out_ref[self.out_ref < self.t_min] = self.t_min
self.out_ref[self.out_ref > self.t_max] = self.t_max
self.out_ref = self.out_ref.astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_fluid_api(self):
@@ -1344,7 +1396,7 @@ class TestRelu6API(unittest.TestCase):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 10, [10, 12]).astype(np.float64)
self.x_np[np.abs(self.x_np) < 0.005] = 0.02
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -1430,7 +1482,7 @@ class TestHardswishAPI(unittest.TestCase):
# test paddle.nn.Hardswish, paddle.nn.functional.hardswish
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -1555,7 +1607,7 @@ class TestELUAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -2055,7 +2107,7 @@ class TestSoftplusAPI(unittest.TestCase):
self.threshold = 15
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -2134,7 +2186,7 @@ class TestSoftsignAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -2219,7 +2271,7 @@ class TestThresholdedReluAPI(unittest.TestCase):
np.random.seed(1024)
self.x_np = np.random.uniform(-20, 20, [10, 12]).astype(np.float64)
self.x_np[np.abs(self.x_np) < 0.005] = 0.02
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
@@ -2317,12 +2369,12 @@ class TestHardsigmoidAPI(unittest.TestCase):
# test paddle.nn.Hardsigmoid, paddle.nn.functional.hardsigmoid
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.x_np.shape, self.x_np.dtype)
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.hardsigmoid(x)
m = paddle.nn.Hardsigmoid()
out2 = m(x)
@@ -2400,13 +2452,13 @@ class TestSwishAPI(unittest.TestCase):
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.x_np.shape, self.x_np.dtype)
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.swish(x)
swish = paddle.nn.Swish()
out2 = swish(x)
@@ -2483,6 +2535,7 @@ create_test_error_class('rsqrt')
create_test_error_class('sin')
create_test_error_class('sqrt')
create_test_error_class('tanh')
create_test_error_class('tan')
#------------------ Test Cudnn Activation----------------------
@@ -2509,7 +2562,7 @@ def create_test_act_fp16_class(parent,
atol=1e-3,
grad_check=True,
grad_atol=0.80):
@unittest.skipIf(not core.is_compiled_with_cuda(),
@unittest.skipIf(not paddle.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestActFp16(parent):
def init_dtype(self):
@@ -2545,6 +2598,7 @@ create_test_act_fp16_class(TestAbs)
create_test_act_fp16_class(TestCeil, grad_check=False)
create_test_act_fp16_class(TestFloor, grad_check=False)
create_test_act_fp16_class(TestCos, grad_atol=0.85)
create_test_act_fp16_class(TestTan, grad_atol=0.85)
create_test_act_fp16_class(TestCosh, grad_atol=0.85)
create_test_act_fp16_class(TestAcos, grad_atol=0.85)
create_test_act_fp16_class(TestSin)
......
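TestTan.test_backward above only asserts the gradient's shape. A possible extension (a sketch, not part of the test suite) would also compare the dygraph gradient values against the analytic 1 / cos(x)^2:

import numpy as np
import paddle

paddle.disable_static()
x_np = np.random.uniform(-1, 1, [3, 4]).astype('float32')
x = paddle.to_tensor(x_np)
x.stop_gradient = False
loss = paddle.tan(x)
loss.backward()  # non-scalar backward, seeded with ones, as in the test above
assert np.allclose(x.gradient(), 1.0 / np.cos(x_np) ** 2, atol=1e-5)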
@@ -104,6 +104,7 @@ from .math import asin #DEFINE_ALIAS
from .math import atan #DEFINE_ALIAS
from .math import ceil #DEFINE_ALIAS
from .math import cos #DEFINE_ALIAS
from .math import tan #DEFINE_ALIAS
from .math import cosh #DEFINE_ALIAS
from .math import cumsum #DEFINE_ALIAS
# from .math import elementwise_add #DEFINE_ALIAS
......
@@ -33,6 +33,7 @@ from ..fluid.layers import acos #DEFINE_ALIAS
from ..fluid.layers import asin #DEFINE_ALIAS
from ..fluid.layers import ceil #DEFINE_ALIAS
from ..fluid.layers import cos #DEFINE_ALIAS
from ..fluid.layers import tan #DEFINE_ALIAS
from ..fluid.layers import sinh #DEFINE_ALIAS
from ..fluid.layers import cosh #DEFINE_ALIAS
# from ..fluid.layers import elementwise_add #DEFINE_ALIAS
......