Unverified commit 948c57d8, authored by kinghuin, committed by GitHub

move sin, sqrt, tanh, atan to paddle.tensor.math and add a new parameter "out" (#23387)

* sin sqrt tanh atan add out, test=develop

* optimize doc, test=develop

* add dygraph test, test=develop
Parent a2e9af56
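
Before the diff, a minimal sketch of what the new parameter enables, mirroring the tests added below (hedged illustration against this PR's build of Paddle):

import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program()):
    data = fluid.layers.data(name="X", shape=[1])
    # new in this PR: sin/sqrt/tanh/atan are exposed as paddle.* and accept `out`;
    # here the result is written back into `data` itself
    out = paddle.tanh(data, out=data)
    exe = fluid.Executor(fluid.CPUPlace())
    result = exe.run(feed={"X": np.array([0.1], dtype='float32')},
                     fetch_list=[data, out])
    # both fetches refer to the same storage, so they compare equal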
......@@ -169,14 +169,14 @@ $$out = \\log \\frac{1}{1 + e^{-x}}$$
UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Operator. Computes exp of x element-wise, with the natural constant :math:`e` as the base.
$out = e^x$
$$out = e^x$$
)DOC";
UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.
$out = \max(x, 0)$
$$out = \max(x, 0)$$
)DOC";
......@@ -209,42 +209,42 @@ Rsqrt Activation Operator.
Please make sure the input is legal, to avoid numeric errors.
$out = \frac{1}{\sqrt{x}}$
$$out = \frac{1}{\sqrt{x}}$$
)DOC";
UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.
$out = |x|$
$$out = |x|$$
)DOC";
UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Operator. Computes ceil of x element-wise.
$out = \left \lceil x \right \rceil$
$$out = \left \lceil x \right \rceil$$
)DOC";
UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.
$out = \left \lfloor x \right \rfloor$
$$out = \left \lfloor x \right \rfloor$$
)DOC";
UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Operator. Computes cosine of x element-wise.
$out = cos(x)$
$$out = \cos(x)$$
)DOC";
UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.
$out = sin(x)$
$$out = \sin(x)$$
)DOC";
......@@ -273,7 +273,7 @@ $$out = \\frac{1}{x}$$
UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.
$out = \ln(x)$
$$out = \ln(x)$$
Natural logarithm of x.
......@@ -282,14 +282,14 @@ Natural logarithm of x.
UNUSED constexpr char SquareDoc[] = R"DOC(
This OP squares each element of the input.
$out = x^2$
$$out = x^2$$
)DOC";
UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.
$out = \ln(1 + e^{x})$
$$out = \ln(1 + e^{x})$$
)DOC";
......@@ -423,7 +423,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
BRelu Activation Operator.
$out = \min(\max(x, t_{min}), t_{max})$
$$out = \min(\max(x, t_{min}), t_{max})$$
)DOC");
}
......@@ -439,7 +439,7 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
SoftRelu Activation Operator.
$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$
$$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$
)DOC");
}
......@@ -461,7 +461,7 @@ ELU Activation Operator.
Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.
$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
$$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$
)DOC");
}
......@@ -482,7 +482,7 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Relu6 Activation Operator.
$out = \min(\max(0, x), threshold)$
$$out = \min(\max(0, x), threshold)$$
)DOC");
}
......@@ -502,7 +502,7 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Pow Activation Operator.
$out = x^{factor}$
$$out = x^{factor}$$
)DOC");
}
......@@ -568,7 +568,7 @@ HardSigmoid Activation Operator.
A 3-part piecewise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
$out = \max(0, \min(1, slope * x + offset))$
$$out = \max(0, \min(1, slope * x + offset))$$
)DOC");
}
......@@ -608,7 +608,7 @@ HardSwish Activation Operator.
The hard version of swish (https://arxiv.org/pdf/1905.02244.pdf).
$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$
$$out = \frac{x * (\min(\max(0, x+offset), threshold))}{scale}$$
The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
......
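
As an aside, the HardSwish doc above carries parameter constraints; here is a NumPy sketch of that formula (the defaults threshold=6, scale=6, offset=3 are an assumption taken from the cited paper, not something this diff pins down):

import numpy as np

def hard_swish_ref(x, threshold=6.0, scale=6.0, offset=3.0):
    # out = x * min(max(0, x + offset), threshold) / scale
    return x * np.minimum(np.maximum(x + offset, 0.0), threshold) / scale

print(hard_swish_ref(np.array([-4.0, -1.0, 0.0, 2.0])))
# approximately [ 0.  -0.3333  0.  1.6667]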
......@@ -92,7 +92,7 @@ from .tensor.logic import equal #DEFINE_ALIAS
# from .tensor.math import abs #DEFINE_ALIAS
# from .tensor.math import acos #DEFINE_ALIAS
# from .tensor.math import asin #DEFINE_ALIAS
# from .tensor.math import atan #DEFINE_ALIAS
from .tensor.math import atan #DEFINE_ALIAS
# from .tensor.math import ceil #DEFINE_ALIAS
# from .tensor.math import cos #DEFINE_ALIAS
# from .tensor.math import cumsum #DEFINE_ALIAS
......@@ -121,13 +121,13 @@ from .tensor.logic import equal #DEFINE_ALIAS
# from .tensor.math import rsqrt #DEFINE_ALIAS
# from .tensor.math import scale #DEFINE_ALIAS
# from .tensor.math import sign #DEFINE_ALIAS
# from .tensor.math import sin #DEFINE_ALIAS
# from .tensor.math import sqrt #DEFINE_ALIAS
from .tensor.math import sin #DEFINE_ALIAS
from .tensor.math import sqrt #DEFINE_ALIAS
# from .tensor.math import square #DEFINE_ALIAS
# from .tensor.math import stanh #DEFINE_ALIAS
# from .tensor.math import sum #DEFINE_ALIAS
# from .tensor.math import sums #DEFINE_ALIAS
# from .tensor.math import tanh #DEFINE_ALIAS
from .tensor.math import tanh #DEFINE_ALIAS
# from .tensor.math import elementwise_sum #DEFINE_ALIAS
# from .tensor.math import max #DEFINE_ALIAS
# from .tensor.math import min #DEFINE_ALIAS
......
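
The net effect of the four uncommented aliases above is that the ops become reachable from the top-level package (sketch against this PR's build; previously they were only available as fluid.layers.* ops):

import paddle
# each name now resolves to the generated function in paddle.tensor.math
print(paddle.atan, paddle.sin, paddle.sqrt, paddle.tanh)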
......@@ -19,6 +19,7 @@ import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from scipy.special import expit, erf
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
......@@ -66,6 +67,36 @@ class TestActivation(OpTest):
        pass


class TestParameter(object):
    def test_out(self):
        with fluid.program_guard(fluid.Program()):
            data = fluid.layers.data(name="X", shape=[1])
            # dispatch by op name, e.g. paddle.tanh(data, out=data)
            out = eval("paddle.%s(data, out=data)" % self.op_type)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result = exe.run(feed={"X": np.array([0.1])},
                             fetch_list=[data, out])
            self.assertEqual(result[0], result[1])

    def test_out_name(self):
        with fluid.program_guard(fluid.Program()):
            data = fluid.layers.data(name="X", shape=[1])
            # `out` wins over `name`: the op warns and still writes into `data`
            out = eval("paddle.%s(data, name='Y', out=data)" % self.op_type)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result = exe.run(feed={"X": np.array([0.1])},
                             fetch_list=[data, out])
            self.assertEqual(result[0], result[1])

    def test_dygraph(self):
        with fluid.dygraph.guard():
            np_x = np.array([0.1])
            x = fluid.dygraph.to_variable(np_x)
            z = eval("paddle.%s(x).numpy()" % self.op_type)
            # assumes numpy exposes a same-named function (sin, sqrt, tanh)
            z_expected = eval("np.%s(np_x)" % self.op_type)
            self.assertEqual(z, z_expected)


class TestSigmoid(TestActivation):
    def setUp(self):
        self.op_type = "sigmoid"
......@@ -103,7 +134,7 @@ class TestLogSigmoid(TestActivation):
        self.check_grad(['X'], 'Out', max_relative_error=0.008)
class TestTanh(TestActivation):
class TestTanh(TestActivation, TestParameter):
    def setUp(self):
        self.op_type = "tanh"
        self.init_dtype()
......@@ -125,7 +156,7 @@ class TestTanh(TestActivation):
        self.dtype = np.float32
class TestAtan(TestActivation):
class TestAtan(TestActivation, TestParameter):
    def setUp(self):
        self.op_type = "atan"
        self.init_dtype()
......@@ -141,6 +172,14 @@ class TestAtan(TestActivation):
            return
        self.check_grad(['X'], 'Out')
    def test_dygraph(self):
        with fluid.dygraph.guard():
            np_x = np.array([0.1])
            x = fluid.dygraph.to_variable(np_x)
            z = paddle.atan(x).numpy()
            # numpy has no `atan`, so override the mixin's generic
            # test_dygraph and compare against np.arctan instead
            z_expected = np.arctan(np_x)
            self.assertEqual(z, z_expected)
class TestTanhShrink(TestActivation):
    def setUp(self):
......@@ -200,7 +239,7 @@ class TestSoftShrink(TestActivation):
        self.check_grad(['X'], 'Out')
class TestSqrt(TestActivation):
class TestSqrt(TestActivation, TestParameter):
    def setUp(self):
        self.op_type = "sqrt"
        self.init_dtype()
......@@ -324,7 +363,7 @@ class TestAcos(TestActivation):
        self.check_grad(['X'], 'Out')
class TestSin(TestActivation):
class TestSin(TestActivation, TestParameter):
    def setUp(self):
        self.op_type = "sin"
        self.init_dtype()
......
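
The mixin keeps opt-in cheap: an activation test gains the out/name/dygraph checks just by listing TestParameter as a base. A hypothetical further opt-in would look like this (sketch only; "cos" would also need registering in math.py first):

class TestCos(TestActivation, TestParameter):
    def setUp(self):
        self.op_type = "cos"
        self.init_dtype()
        # ... inputs/outputs set up the same way as in the sibling tests elided above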
......@@ -69,7 +69,7 @@ from .logic import equal #DEFINE_ALIAS
# from .math import abs #DEFINE_ALIAS
# from .math import acos #DEFINE_ALIAS
# from .math import asin #DEFINE_ALIAS
# from .math import atan #DEFINE_ALIAS
from .math import atan #DEFINE_ALIAS
# from .math import ceil #DEFINE_ALIAS
# from .math import cos #DEFINE_ALIAS
# from .math import cumsum #DEFINE_ALIAS
......@@ -98,13 +98,13 @@ from .logic import equal #DEFINE_ALIAS
# from .math import rsqrt #DEFINE_ALIAS
# from .math import scale #DEFINE_ALIAS
# from .math import sign #DEFINE_ALIAS
# from .math import sin #DEFINE_ALIAS
# from .math import sqrt #DEFINE_ALIAS
from .math import sin #DEFINE_ALIAS
from .math import sqrt #DEFINE_ALIAS
# from .math import square #DEFINE_ALIAS
# from .math import stanh #DEFINE_ALIAS
# from .math import sum #DEFINE_ALIAS
# from .math import sums #DEFINE_ALIAS
# from .math import tanh #DEFINE_ALIAS
from .math import tanh #DEFINE_ALIAS
# from .math import elementwise_sum #DEFINE_ALIAS
# from .math import max #DEFINE_ALIAS
# from .math import min #DEFINE_ALIAS
......
......@@ -12,11 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define math functions
# __all__ = ['abs',
from __future__ import print_function
import warnings
from ..fluid.framework import OpProtoHolder, core, in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layers.layer_function_generator import _generate_doc_string_
# TODO: define math functions
# yapf: disable
__all__ = [
# 'abs',
# 'acos',
# 'asin',
# 'atan',
'atan',
# 'ceil',
# 'cos',
# 'cumsum',
......@@ -45,13 +56,13 @@
# 'rsqrt',
# 'scale',
# 'sign',
# 'sin',
# 'sqrt',
'sin',
'sqrt',
# 'square',
# 'stanh',
# 'sum',
# 'sums',
# 'tanh',
'tanh',
# 'elementwise_sum',
# 'max',
# 'min',
......@@ -65,3 +76,81 @@
# 'erf',
# 'addcmul',
# 'addmm']
]
# yapf: enable.
def generate_op_noattr(op_type):
    """Register the Python layer for an Operator without Attributes.

    Args:
        op_type: The name of the operator to be created.

    This function takes in the operator type (sin, tanh, etc.) and
    creates the operator functionality.
    """
    op_proto = OpProtoHolder.instance().get_op_proto(op_type)

    def func(x, name=None, out=None):
        # dygraph mode calls straight into the C++ op; `name` and `out`
        # are static-graph concepts and are ignored here
        if in_dygraph_mode():
            op = getattr(core.ops, op_type)
            return op(x)

        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 op_type)
        helper = LayerHelper(op_type, **locals())

        if name and out:
            warnings.warn(
                "Both the name and out parameters have been set in fluid.tensor.math.%s(); "
                "only out will take effect and specify the result storage. "
                "Discard one of the two to silence this warning." % op_type,
                category=UserWarning,
                stacklevel=2)
        if not out:
            out = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": out})
        return out

    func.__name__ = op_type
    func.__doc__ = _generate_doc_string_(
        op_proto,
        additional_args_lines=[
            "name(str, optional): The default value is None. Normally there is no need for the user to set this property. For more information, please refer to :ref:`api_guide_Name`.\n"
            "out(Variable, optional): The default value is None. An optional output Variable in which to store the result of the operation. If out is None, a new Variable is created to store the result."
        ])
    func.__doc__ = func.__doc__ + """

Return type
    Variable

Examples:
    .. code-block:: python

        import numpy as np

        import paddle
        import paddle.fluid as fluid

        inputs = fluid.data(name="x", shape=[None, 4], dtype='float32')
        output = paddle.%s(inputs)

        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(fluid.default_startup_program())

        # input shape: 1x4, batch_size=1
        img = np.array([[1.0, 2.0, 3.0, 4.0]]).astype(np.float32)
        res = exe.run(fluid.default_main_program(),
                      feed={'x': img},
                      fetch_list=[output])
        print(res)
""" % op_type
    return func


__ops__noattr__ = [
    'atan',
    'sin',
    'sqrt',
    'tanh',
]

for _OP in set(__ops__noattr__):
    globals()[_OP] = generate_op_noattr(_OP)
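
To make the registration loop concrete, it is equivalent to writing out (illustrative only, not part of the diff):

atan = generate_op_noattr('atan')
sin = generate_op_noattr('sin')
sqrt = generate_op_noattr('sqrt')
tanh = generate_op_noattr('tanh')

# each generated function is a real, introspectable function object whose
# docstring was assembled from the C++ OpProto plus additional_args_lines:
print(sin.__name__)                   # 'sin'
print('out(Variable' in sin.__doc__)  # True

Assigning through globals() lets one factory serve all four ops while the module still exports plain functions that documentation tools can pick up.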