Unverified commit acc25d0b, authored by zhiboniu, committed by GitHub

tensor fluid code transfer part1 (#41094)

Parent: f43af275
@@ -22,7 +22,7 @@ from paddle.fluid.core import VarDesc
from paddle.fluid import core, dygraph_utils
from paddle.fluid.data_feeder import check_type, check_dtype, check_variable_and_dtype, convert_dtype
from paddle.fluid.layers import fill_constant, utils, scale
-from paddle.fluid.layers.layer_function_generator import templatedoc
+from paddle.tensor.layer_function_generator import templatedoc
import paddle.fluid as fluid
import numpy
import warnings
@@ -68,26 +68,26 @@ class TestMultiplexOpError(unittest.TestCase):
def test_list():
# the inputs type must be list
-fluid.layers.multiplex(inputs=x1, index=index)
+paddle.multiplex(inputs=x1, index=index)
self.assertRaises(TypeError, test_list)
def test_len():
-fluid.layers.multiplex(inputs=[x1], index=index)
+paddle.multiplex(inputs=[x1], index=index)
self.assertRaises(ValueError, test_len)
def test_type():
y1 = fluid.data(name='y1', shape=[None, 2], dtype='int16')
y2 = fluid.data(name='y2', shape=[None, 2], dtype='int16')
-fluid.layers.multiplex(inputs=[y1, y2], index=index)
+paddle.multiplex(inputs=[y1, y2], index=index)
self.assertRaises(TypeError, test_type)
def test_type2():
index2 = fluid.data(
name='index2', shape=[None, 1], dtype='int16')
-fluid.layers.multiplex(inputs=[x1, x2], index=index2)
+paddle.multiplex(inputs=[x1, x2], index=index2)
self.assertRaises(TypeError, test_type2)
@@ -53,4 +53,7 @@ from ..fluid.framework import dygraph_only # noqa: F401
from ..fluid.framework import convert_np_dtype_to_dtype_, _varbase_creator, OpProtoHolder # noqa: F401
from ..fluid.framework import _dygraph_tracer # noqa: F401
from ..fluid.layer_helper import LayerHelper # noqa: F401
from ..fluid.framework import in_dygraph_mode # noqa: F401
__all__ = []
@@ -20,7 +20,7 @@ from ...fluid.layer_helper import LayerHelper
from ...static import Variable
from ...tensor.creation import assign
from ...fluid import dygraph_utils
-from ...fluid.layers.layer_function_generator import templatedoc
+from ...tensor.layer_function_generator import templatedoc
from ...fluid.layers.sequence_lod import sequence_mask #noqa: F401
from paddle import in_dynamic_mode
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import re
import functools
import warnings
import string
from six.moves import cStringIO
from ..static import Variable
from ..fluid.proto import framework_pb2
from ..framework import OpProtoHolder, core, convert_np_dtype_to_dtype_
from ..framework import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype
import paddle
from paddle import _C_ops
__all__ = []
def _convert_(name):
"""
Formatting.
Args:
name: The name/alias
This function takes in a name and converts it to a standard format of
group1_group2. Where as per the regular expression, group1 can have
alphabets and numbers and group2 has capital alphabets.
"""
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
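A standalone sketch of the same two regex passes is handy for checking how proto names map to Python argument names (`to_snake` here is a hypothetical stand-in for `_convert_`):

```python
import re

def to_snake(name):
    # Pass 1: split before a capitalized word ("ReduceProd" -> "Reduce_Prod").
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    # Pass 2: split between a lowercase/digit and a capital, then lowercase.
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

print(to_snake("ReduceProd"))  # reduce_prod
print(to_snake("X"))           # x  (how an input named "X" becomes the arg "x")
```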
def _type_to_str_(tp):
return framework_pb2.AttrType.Name(tp)
_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
def escape_math(text):
#return _two_bang_pattern_.sub(
# r'$$\1$$',
# _single_dollar_pattern_.sub(r':math:\n`\1`',
# _two_dollar_pattern_.sub(r"!!\1!!", text)))
return _two_dollar_pattern_.sub(r':math:`\1`', text)
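Only the `$$...$$` form is rewritten now (the single-dollar and `!!` passes above are commented out); a self-contained check of the active substitution:

```python
import re

_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")

def escape_math(text):
    # Rewrite $$...$$ spans into reST inline math.
    return _two_dollar_pattern_.sub(r':math:`\1`', text)

print(escape_math("Computes $$e^{x}$$ element-wise."))
# Computes :math:`e^{x}` element-wise.
```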
def _generate_doc_string_(op_proto,
additional_args_lines=None,
skip_attrs_set=None):
"""
Generate the docstring of an operator from its OpProto
Args:
op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
Returns:
str: the document string
"""
if not isinstance(op_proto, framework_pb2.OpProto):
raise TypeError("OpProto should be `framework_pb2.OpProto`")
buf = cStringIO()
buf.write(escape_math(op_proto.comment))
buf.write('\nArgs:\n')
for each_input in op_proto.inputs:
line_begin = ' {0}'.format(_convert_(each_input.name))
buf.write(line_begin)
buf.write(" (Tensor): ")
buf.write(escape_math(each_input.comment))
if each_input.duplicable:
buf.write(" Duplicatable.")
if each_input.dispensable:
buf.write(" Optional.")
buf.write('\n')
skip_attrs = OpProtoHolder.generated_op_attr_names()
# attr use_mkldnn and is_test also should not be visible to users.
skip_attrs.add("use_mkldnn")
skip_attrs.add("is_test")
skip_attrs.add("use_cudnn")
if skip_attrs_set:
for t in skip_attrs_set:
skip_attrs.add(t)
for each_attr in op_proto.attrs:
if each_attr.name in skip_attrs:
continue
buf.write(' ')
buf.write(each_attr.name)
buf.write(' (')
buf.write(_type_to_str_(each_attr.type))
buf.write('): ')
buf.write(escape_math(each_attr.comment))
buf.write('\n')
if additional_args_lines is not None:
for line in additional_args_lines:
line = line.strip()
buf.write(' ')
buf.write(line)
buf.write('\n')
if len(op_proto.outputs) != 0:
buf.write('\nReturns:\n')
buf.write(' ')
for each_opt in op_proto.outputs:
if not each_opt.intermediate:
break
buf.write(_convert_(each_opt.name))
buf.write(' (Tensor): ')
buf.write(escape_math(each_opt.comment))
return buf.getvalue()
def generate_layer_fn(op_type):
"""Register the Python layer for an Operator.
Args:
op_type: The name of the operator to be created.
This function takes in the operator type (sigmoid, mean, average, etc.)
and creates the operator functionality.
"""
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
not_intermediate_outputs = \
[output for output in op_proto.outputs if not output.intermediate]
intermediate_outputs = \
[output for output in op_proto.outputs if output.intermediate]
if len(not_intermediate_outputs) != 1:
raise ValueError(
"Only one non-intermediate output operator can be "
"automatically generated. {0}".format(op_type))
if not_intermediate_outputs[0].duplicable:
raise ValueError(
"Only non duplicable op can be automatically generated.")
for output in intermediate_outputs:
if output.duplicable:
raise ValueError(
"The op can be automatically generated only when "
"all intermediate ops are not duplicable.")
o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs]
def infer_and_check_dtype(op_proto, *args, **kwargs):
"""
This function performs the sanity check for dtype and
instance type.
"""
dtype = None
for ipt in op_proto.inputs:
name = _convert_(ipt.name)
val = kwargs.pop(name, [])
if not isinstance(val, list) and not isinstance(val, tuple):
val = [val]
if len(val) == 0:
if len(args) == 0:
continue
val = [args[0]]
args = args[1:]
for each in val:
if not isinstance(each, Variable):
raise ValueError("input of {0} must be variable".format(
op_type))
if dtype is None:
dtype = each.dtype
elif dtype != each.dtype:
raise ValueError(
"operator {0} must input same dtype. {1} vs {2}".format(
op_type, dtype, each.dtype))
if dtype is None:
arg_dtype = kwargs.get("dtype")
if arg_dtype:
if not isinstance(arg_dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(arg_dtype)
else:
dtype = arg_dtype
else:
dtype = core.VarDesc.VarType.FP32
return dtype
def func(*args, **kwargs):
helper = LayerHelper(op_type, **kwargs)
dtype = infer_and_check_dtype(op_proto, *args, **kwargs)
inputs = dict()
for ipt in op_proto.inputs:
name = _convert_(ipt.name)
val = kwargs.pop(name, [])
if not isinstance(val, list) and not isinstance(val, tuple):
val = [val]
if len(val) == 0 and len(args) != 0:
val = args[0]
args = args[1:]
inputs[ipt.name] = val
outputs = dict()
out = kwargs.pop(_convert_(o_name), [])
if out:
out_var = out[0] if (isinstance(out, list) or
isinstance(out, tuple)) else out
else:
out_var = helper.create_variable_for_type_inference(dtype=dtype)
outputs[o_name] = [out_var]
for name in intermediate_output_names:
outputs[name] = [
helper.create_variable_for_type_inference(dtype=dtype)
]
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
return helper.append_activation(out_var)
func.__name__ = op_type
func.__doc__ = _generate_doc_string_(op_proto)
return func
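A hedged usage sketch, mirroring the hot fix in ops.py further down (it assumes a Paddle build where the `scale` OpProto is registered, and uses the import path this commit introduces):

```python
from paddle.tensor.layer_function_generator import generate_layer_fn

# Same call ops.py makes: globals()['_scale'] = generate_layer_fn('scale')
_scale = generate_layer_fn('scale')
print(_scale.__name__)      # scale
print(_scale.__doc__[:60])  # docstring generated from the OpProto comment
```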
def generate_activation_fn(op_type):
"""Register the Python layer for an Operator without Attribute.
Args:
op_type: The name of the operator to be created.
This function takes in the operator type (sigmoid, exp, tanh, etc.)
and creates the operator functionality.
"""
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
def func(x, name=None):
if paddle.in_dynamic_mode():
op = getattr(_C_ops, op_type)
return op(x)
if op_type not in ["abs", "exp", "square"]:
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
op_type)
else:
# abs exp square ops support dtype(int32, int64, float16, float32, float64)
check_variable_and_dtype(
x, 'x', ['int32', 'int64', 'float16', 'float32', 'float64'],
op_type)
helper = LayerHelper(op_type, **locals())
output = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output})
return output
func.__name__ = op_type
func.__doc__ = _generate_doc_string_(
op_proto,
additional_args_lines=[
"name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`."
])
return func
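In dynamic mode the generated function dispatches straight to `_C_ops`; a quick check through a registered result (a sketch assuming a standard Paddle install):

```python
import paddle

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
# paddle.cos is produced by generate_activation_fn('cos') at import time
print(paddle.cos(x))  # [0.92106099, 0.98006658, 0.99500417, 0.95533649]
```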
def generate_inplace_fn(inplace_op_type):
"""Register the Python layer for an Inplace Operator without Attribute.
Args:
inplace_op_type: The name of the inplace operator to be created.
This function takes in the inplace operator type (exp_, ceil_, etc.) and
creates the operator functionality.
"""
origin_op_type = inplace_op_type[:-1]
def func(x, name=None):
if paddle.in_dynamic_mode():
op = getattr(_C_ops, inplace_op_type)
return op(x)
warnings.warn(
"In static mode, {}() is the same as {}() and does not perform an inplace operation.".
format(inplace_op_type, origin_op_type))
return generate_activation_fn(origin_op_type)(x, name)
func.__name__ = inplace_op_type
func.__doc__ = """
Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_fluid_layers_{1}`.
""".format(origin_op_type, origin_op_type)
return func
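A minimal sketch of the in-place contract in dynamic mode (assuming a standard Paddle install; `exp_` is registered below via `generate_inplace_fn('exp_')`):

```python
import paddle

x = paddle.to_tensor([0.0, 1.0])
x.exp_()   # mutates x in place instead of allocating a new Tensor
print(x)   # [1., 2.71828183]
```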
def templatedoc(op_type=None):
"""
Decorator of layer function. It will use the docstring from the layer
function as the template. The template arguments are:
* ${comment}: The operator comment written in CPP.
* ${{name}_comment}: The comment of ${name} written with AddAttr, AddOutput,
and AddInput. The ${name} is in Python snake_case style, i.e., xxx_xxx.
* ${{name}_type}: The type of ${name}.
Returns:
Decorated function.
"""
def trim_ending_dot(msg):
return msg.rstrip('.')
def __impl__(func):
if op_type is None:
op_type_name = func.__name__
else:
op_type_name = op_type
op_proto = OpProtoHolder.instance().get_op_proto(op_type_name)
tmpl = string.Template(func.__doc__)
comment_lines = op_proto.comment.split("\n")
comment = ""
for line in comment_lines:
line = line.strip()
if len(line) != 0:
comment += escape_math(line)
comment += " "
elif len(comment) != 0:
comment += "\n \n "
args = {"comment": trim_ending_dot(comment)}
for each_input in op_proto.inputs:
input_name = _convert_(each_input.name)
args["{0}_comment".format(input_name)] = trim_ending_dot(
each_input.comment)
args["{0}_type".format(input_name)] = "Variable"
for each_attr in op_proto.attrs:
input_name = _convert_(each_attr.name)
args["{0}_comment".format(input_name)] = trim_ending_dot(
each_attr.comment)
args["{0}_type".format(input_name)] = _type_to_str_(each_attr.type)
for each_opt in op_proto.outputs:
output_name = _convert_(each_opt.name)
args["{0}_comment".format(output_name)] = trim_ending_dot(
each_opt.comment)
args["{0}_type".format(output_name)] = "Variable"
func.__doc__ = tmpl.substitute(args)
return func
return __impl__
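A hypothetical decorated layer showing the substitution (assumes the `relu` OpProto is registered; `my_relu` is illustrative only):

```python
from paddle.tensor.layer_function_generator import templatedoc

@templatedoc(op_type="relu")
def my_relu(x, name=None):
    """
    ${comment}

    Args:
        x (${x_type}): ${x_comment}
    """

print(my_relu.__doc__)  # placeholders replaced from the C++ OpProto of relu
```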
def add_sample_code(func, sample_code):
"""
Append sample code for dynamically generated functions.
Args:
func: The function of the function to be append sample code to.
sample_code: sample code session in rst format.
"""
func.__doc__ = func.__doc__ + sample_code
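Usage is mechanical; the registration file below calls it for every generated function (a minimal self-contained sketch):

```python
from paddle.tensor.layer_function_generator import add_sample_code

def dummy(x):
    """Docstring generated elsewhere."""
    return x

add_sample_code(dummy, """
Examples:
    .. code-block:: python

        out = dummy(1)
""")
print(dummy.__doc__)  # original docstring with the sample appended
```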
@@ -14,7 +14,7 @@
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_type, check_variable_and_dtype
-from ..fluid.layers.layer_function_generator import templatedoc
+from .layer_function_generator import templatedoc
from ..static import Variable
from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode
# TODO: define logic functions of a tensor
@@ -23,56 +23,52 @@ from paddle.common_ops_import import OpProtoHolder
from paddle.common_ops_import import templatedoc
from paddle.common_ops_import import dygraph_utils
-from paddle.tensor import cast
-from paddle.tensor.attribute import _complex_to_real_dtype
+from .manipulation import cast
+from .creation import _complex_to_real_dtype
+from .layer_function_generator import _generate_doc_string_, generate_activation_fn, generate_layer_fn
import paddle
-from paddle.static import Variable
-from ..framework import core
-from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
+from ..static import Variable
+from ..framework import core, in_dygraph_mode, _non_static_mode, LayerHelper
+from ..fluid.framework import _in_legacy_dygraph
from ..framework import _varbase_creator, convert_np_dtype_to_dtype_
-from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
-from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn, generate_layer_fn
from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only
# TODO: define math functions
# yapf: disable
-from ..fluid.layers import abs # noqa: F401
-from ..fluid.layers import acos # noqa: F401
-from ..fluid.layers import asin # noqa: F401
-from ..fluid.layers import ceil # noqa: F401
-from ..fluid.layers import ceil_ # noqa: F401
-from ..fluid.layers import cos # noqa: F401
-from ..fluid.layers import tan # noqa: F401
-from ..fluid.layers import sinh # noqa: F401
-from ..fluid.layers import cosh # noqa: F401
-from ..fluid.layers import exp # noqa: F401
-from ..fluid.layers import exp_ # noqa: F401
-from ..fluid.layers import expm1 # noqa: F401
-from ..fluid.layers import floor # noqa: F401
-from ..fluid.layers import floor_ # noqa: F401
-from ..fluid.layers import log # noqa: F401
-from ..fluid.layers import reciprocal # noqa: F401
-from ..fluid.layers import reciprocal_ # noqa: F401
-from ..fluid.layers import round # noqa: F401
-from ..fluid.layers import round_ # noqa: F401
-from ..fluid.layers import rsqrt # noqa: F401
-from ..fluid.layers import rsqrt_ # noqa: F401
-from ..fluid.layers import scale # noqa: F401
-from ..fluid.layers import square # noqa: F401
-from ..fluid.layers import stanh # noqa: F401
-from ..fluid.layers import atan # noqa: F401
-from ..fluid.layers import erf # noqa: F401
-from ..fluid.layers import sqrt # noqa: F401
-from ..fluid.layers import sqrt_ # noqa: F401
-from ..fluid.layers import sin # noqa: F401
-from ..fluid.layers import lgamma # noqa: F401
-from ..fluid.layers import asinh # noqa: F401
-from ..fluid.layers import acosh # noqa: F401
-from ..fluid.layers import atanh # noqa: F401
-from ..fluid.layers import multiplex # noqa: F401
-from ..fluid.layers import reduce_prod
+from .ops import abs # noqa: F401
+from .ops import acos # noqa: F401
+from .ops import asin # noqa: F401
+from .ops import ceil # noqa: F401
+from .ops import ceil_ # noqa: F401
+from .ops import cos # noqa: F401
+from .ops import tan # noqa: F401
+from .ops import sinh # noqa: F401
+from .ops import cosh # noqa: F401
+from .ops import exp # noqa: F401
+from .ops import exp_ # noqa: F401
+from .ops import expm1 # noqa: F401
+from .ops import floor # noqa: F401
+from .ops import floor_ # noqa: F401
+from .ops import reciprocal # noqa: F401
+from .ops import reciprocal_ # noqa: F401
+from .ops import round # noqa: F401
+from .ops import round_ # noqa: F401
+from .ops import rsqrt # noqa: F401
+from .ops import rsqrt_ # noqa: F401
+from .ops import square # noqa: F401
+from .ops import atan # noqa: F401
+from .ops import erf # noqa: F401
+from .ops import sqrt # noqa: F401
+from .ops import sqrt_ # noqa: F401
+from .ops import sin # noqa: F401
+from .ops import lgamma # noqa: F401
+from .ops import asinh # noqa: F401
+from .ops import acosh # noqa: F401
+from .ops import atanh # noqa: F401
from ..fluid.layers import elementwise_sub
from paddle import _C_ops
@@ -92,6 +88,241 @@ _supported_float_dtype_ = [
]
def log(x, name=None):
r"""
Calculates the natural log of the given input Tensor, element-wise.
.. math::
Out = \ln(x)
Args:
x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
name (str|None): The default value is None. Normally there is no need for the user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The natural log of the input Tensor computed element-wise.
Examples:
.. code-block:: python
import paddle
x = [[2,3,4], [7,8,9]]
x = paddle.to_tensor(x, dtype='float32')
res = paddle.log(x)
# [[0.693147, 1.09861, 1.38629], [1.94591, 2.07944, 2.19722]]
"""
if in_dygraph_mode():
return _C_ops.final_state_log(x)
if _in_legacy_dygraph():
return _C_ops.log(x)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], "log")
inputs = {'X': [x]}
helper = LayerHelper('log', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type="log", inputs={"X": x}, outputs={"Out": out})
return out
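Below the two dygraph fast paths, the LayerHelper branch builds a static-graph op; a minimal sketch of exercising it (assumes static mode is enabled first):

```python
import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
y = paddle.log(x)  # goes through the LayerHelper path, appending a "log" op
print(y.shape)     # same shape as x
```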
def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
"""
Scale operator.
Apply scale and bias to the input Tensor as follows:
``bias_after_scale`` is True:
.. math::
Out = scale*X + bias
``bias_after_scale`` is False:
.. math::
Out = scale*(X + bias)
Args:
x(Tensor): Input N-D Tensor of scale operator. Data type can be float32, float64, int8, int16, int32, int64, uint8.
scale(float|Tensor): The scale factor of the input, it should be a float number or a Tensor with shape [1] and data type as float32.
bias(float): The bias to be put on the input.
bias_after_scale(bool): Apply bias addition after or before scaling. It is useful for numeric stability in some circumstances.
act(str, optional): Activation applied to the output such as tanh, softmax, sigmoid, relu.
name(str, optional): The default value is None. Normally there is no need for the user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: Output tensor of scale operator, with shape and data type same as input.
Examples:
.. code-block:: python
# scale as a float32 number
import paddle
data = paddle.randn(shape=[2,3], dtype='float32')
res = paddle.scale(data, scale=2.0, bias=1.0)
.. code-block:: python
# scale with parameter scale as a Tensor
import paddle
data = paddle.randn(shape=[2, 3], dtype='float32')
factor = paddle.to_tensor([2], dtype='float32')
res = paddle.scale(data, scale=factor, bias=1.0)
"""
if in_dygraph_mode():
out = _C_ops.final_state_scale(x, scale, float(bias), bias_after_scale)
return dygraph_utils._append_activation_in_dygraph(out)
if _non_static_mode():
_scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale
out = _C_ops.scale(x, 'scale',
float(_scale), 'bias',
float(bias), 'bias_after_scale', bias_after_scale)
return dygraph_utils._append_activation_in_dygraph(out)
check_variable_and_dtype(x, "x", [
'float16', 'uint16', 'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8'
], "scale")
inputs = {'X': [x]}
attrs = {
'bias': float(bias),
'bias_after_scale': bias_after_scale,
}
if isinstance(scale, Variable):
inputs['ScaleTensor'] = [scale]
else:
attrs['scale'] = float(scale)
helper = LayerHelper('scale', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='scale', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return helper.append_activation(out)
def stanh(x, scale_a=0.67, scale_b=1.7159, name=None):
"""
stanh activation.
.. math::
out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
scale_a (float, optional): The scale factor a of the input. Default is 0.67.
scale_b (float, optional): The scale factor b of the output. Default is 1.7159.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x``.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
out = paddle.stanh(x, scale_a=0.67, scale_b=1.72) # [1.00616539, 1.49927628, 1.65933108, 1.70390463]
"""
if _non_static_mode():
return _C_ops.stanh(x, 'scale_a', scale_a, 'scale_b', scale_b)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'stanh')
helper = LayerHelper('stanh', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='stanh',
inputs={'X': x},
outputs={'Out': out},
attrs={'scale_a': scale_a,
'scale_b': scale_b})
return out
def multiplex(inputs, index, name=None):
"""
Based on the given index parameter, the OP selects a specific row from each input Tensor to construct the output Tensor.
If the input of this OP contains :math:`m` Tensors, where :math:`I_{i}` denotes the i-th input Tensor and :math:`i` ranges over :math:`[0, m)`,
and :math:`O` denotes the output, where :math:`O[i]` denotes the i-th row of the output, then the output satisfies :math:`O[i] = I_{index[i]}[i]` .
For Example:
.. code-block:: text
Given:
inputs = [[[0,0,3,4], [0,1,3,4], [0,2,4,4], [0,3,3,4]],
[[1,0,3,4], [1,1,7,8], [1,2,4,2], [1,3,3,4]],
[[2,0,3,4], [2,1,7,8], [2,2,4,2], [2,3,3,4]],
[[3,0,3,4], [3,1,7,8], [3,2,4,2], [3,3,3,4]]]
index = [[3],[0],[1],[2]]
out = [[3,0,3,4], # out[0] = inputs[index[0]][0] = inputs[3][0] = [3,0,3,4]
[0,1,3,4], # out[1] = inputs[index[1]][1] = inputs[0][1] = [0,1,3,4]
[1,2,4,2], # out[2] = inputs[index[2]][2] = inputs[1][2] = [1,2,4,2]
[2,3,3,4]] # out[3] = inputs[index[3]][3] = inputs[2][3] = [2,3,3,4]
Args:
inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2.
index (Tensor): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors.
name(str, optional): The default value is None. Normally there is no
need for the user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: Output of multiplex OP, with data type being float32, float64, int32, int64.
Examples:
.. code-block:: python
import paddle
import numpy as np
img1 = np.array([[1, 2], [3, 4]]).astype(np.float32)
img2 = np.array([[5, 6], [7, 8]]).astype(np.float32)
inputs = [paddle.to_tensor(img1), paddle.to_tensor(img2)]
index = paddle.to_tensor(np.array([[1], [0]]).astype(np.int32))
res = paddle.multiplex(inputs, index)
print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)]
"""
if _non_static_mode():
return _C_ops.multiplex(index, inputs)
helper = LayerHelper('multiplex', **locals())
check_type(inputs, 'inputs', (list), 'multiplex')
if len(inputs) < 2:
raise ValueError(
"inputs should be a list object with at least 2 elements.")
for idx, x in enumerate(inputs):
check_variable_and_dtype(x, 'input[' + str(idx) + ']',
['float32', 'float64', 'int32', 'int64'],
'multiplex')
check_variable_and_dtype(index, "index", ['int32', 'int64'], 'multiplex')
out = helper.create_variable_for_type_inference(inputs[0].dtype)
helper.append_op(
type='multiplex',
inputs={'X': inputs,
'Ids': index},
outputs={'Out': [out]})
return out
@inplace_apis_in_dygraph_only
def scale_(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
"""
@@ -2973,7 +3204,38 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
if x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)
-return reduce_prod(input=x, dim=axis, keep_dim=keepdim, name=name)
input = x
dim = axis
keep_dim = keepdim
if dim is not None and not isinstance(dim, list):
if isinstance(dim, tuple):
dim = list(dim)
elif isinstance(dim, int):
dim = [dim]
else:
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".
format(type(dim)))
if in_dygraph_mode():
return _C_ops.final_state_reduce_prod(
input, dim if dim is not None and dim != [] else [0], keep_dim,
dim is None or dim == [] or len(dim) == len(input.shape))
helper = LayerHelper('reduce_prod', **locals())
check_variable_and_dtype(
input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod')
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
type='reduce_prod',
inputs={'X': input},
outputs={'Out': out},
attrs={
'dim': dim if dim is not None and dim != [] else [0],
'keep_dim': keep_dim,
'reduce_all': dim is None or dim == [] or
len(dim) == len(input.shape)
})
return out
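`reduce_all` above is True exactly when every axis is reduced; a quick check of both cases through the public `paddle.prod` (a sketch assuming a standard Paddle install):

```python
import paddle

x = paddle.to_tensor([[1., 2.], [3., 4.]])
print(paddle.prod(x))          # 24.  (axis=None reduces all axes)
print(paddle.prod(x, axis=1))  # [2., 12.]  (partial reduction)
```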
def sign(x, name=None):
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
from .layer_function_generator import generate_layer_fn, generate_activation_fn, generate_inplace_fn, add_sample_code
from ..framework import core
from ..framework import convert_np_dtype_to_dtype_
from ..static import Variable
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
__deprecated_func_name__ = {
'tanh_shrink': 'tanhshrink',
'logsigmoid': 'log_sigmoid'
}
__activations_noattr__ = [
'sigmoid',
'silu',
'logsigmoid',
'tanh_shrink',
'softplus',
'softsign',
'tanh',
]
__unary_func__ = [
'exp',
'expm1',
'atan',
'sqrt',
'rsqrt',
'abs',
'ceil',
'floor',
'cos',
'tan',
'acos',
'sin',
'sinh',
'asin',
'cosh',
'round',
'reciprocal',
'square',
'lgamma',
'acosh',
'asinh',
'atanh',
]
__inplace_unary_func__ = [
'exp_',
'sqrt_',
'rsqrt_',
'ceil_',
'floor_',
'round_',
'reciprocal_',
]
__all__ = []
for _OP in set(__all__):
globals()[_OP] = generate_layer_fn(_OP)
# It is a hot fix in some unittest using:
# fluid.layers.scale(x=x, scale=10.0, out=out_var)
# e.g.: test_program_code.py, test_dist_train.py
globals()['_scale'] = generate_layer_fn('scale')
globals()['_elementwise_div'] = generate_layer_fn('elementwise_div')
__all__ += __activations_noattr__
__all__ += __unary_func__
__all__ += __inplace_unary_func__
for _OP in set(__activations_noattr__):
_new_OP = _OP
if _OP in __deprecated_func_name__:
_new_OP = __deprecated_func_name__[_OP]
_func = generate_activation_fn(_OP)
globals()[_OP] = _func
for _OP in set(__unary_func__):
_new_OP = _OP
if _OP in __deprecated_func_name__:
_new_OP = __deprecated_func_name__[_OP]
_func = generate_activation_fn(_OP)
globals()[_OP] = _func
for _OP in set(__inplace_unary_func__):
_new_OP = _OP
if _OP in __deprecated_func_name__:
_new_OP = __deprecated_func_name__[_OP]
_func = generate_inplace_fn(_OP)
globals()[_OP] = _func
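After these three loops run at import time, every listed name is a callable attribute of this module; for example (a sketch assuming a standard Paddle build):

```python
from paddle.tensor import ops

print(ops.sqrt.__name__)        # sqrt
print(ops.floor_.__doc__[:50])  # "Inplace version of ``floor`` API ..."
```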
add_sample_code(globals()["sigmoid"], r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = F.sigmoid(x)
print(out)
# [0.40131234 0.450166 0.52497919 0.57444252]
""")
add_sample_code(globals()["silu"], r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
out = F.silu(x)
print(out)
# [0.7310586, 1.7615942, 2.8577224, 3.9280552]
""")
add_sample_code(globals()["logsigmoid"], r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = F.log_sigmoid(x)
print(out)
# [-0.91301525 -0.79813887 -0.64439666 -0.55435524]
""")
add_sample_code(globals()["exp"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.exp(x)
print(out)
# [0.67032005 0.81873075 1.10517092 1.34985881]
""")
add_sample_code(globals()["expm1"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.expm1(x)
print(out)
# [-0.32967997, -0.18126924, 0.10517092, 0.34985882]
""")
add_sample_code(globals()["tanh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.tanh(x)
print(out)
# [-0.37994896 -0.19737532 0.09966799 0.29131261]
""")
add_sample_code(globals()["atan"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.atan(x)
print(out)
# [-0.38050638 -0.19739556 0.09966865 0.29145679]
""")
add_sample_code(globals()["tanh_shrink"], r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = F.tanhshrink(x)
print(out)
# [-0.020051, -0.00262468, 0.000332005, 0.00868739]
""")
add_sample_code(globals()["sqrt"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0.1, 0.2, 0.3, 0.4])
out = paddle.sqrt(x)
print(out)
# [0.31622777 0.4472136 0.54772256 0.63245553]
""")
add_sample_code(globals()["rsqrt"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0.1, 0.2, 0.3, 0.4])
out = paddle.rsqrt(x)
print(out)
# [3.16227766 2.23606798 1.82574186 1.58113883]
""")
add_sample_code(globals()["abs"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.abs(x)
print(out)
# [0.4 0.2 0.1 0.3]
""")
add_sample_code(globals()["ceil"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.ceil(x)
print(out)
# [-0. -0. 1. 1.]
""")
add_sample_code(globals()["floor"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.floor(x)
print(out)
# [-1. -1. 0. 0.]
""")
add_sample_code(globals()["cos"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.cos(x)
print(out)
# [0.92106099 0.98006658 0.99500417 0.95533649]
""")
add_sample_code(globals()["tan"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.tan(x)
print(out)
# [-0.42279324, -0.20271005, 0.10033467, 0.30933627]
""")
add_sample_code(globals()["acos"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.acos(x)
print(out)
# [1.98231317 1.77215425 1.47062891 1.26610367]
""")
add_sample_code(globals()["sin"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.sin(x)
print(out)
# [-0.38941834 -0.19866933 0.09983342 0.29552021]
""")
add_sample_code(globals()["asin"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.asin(x)
print(out)
# [-0.41151685 -0.20135792 0.10016742 0.30469265]
""")
add_sample_code(globals()["cosh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.cosh(x)
print(out)
# [1.08107237 1.02006676 1.00500417 1.04533851]
""")
add_sample_code(globals()["sinh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.sinh(x)
print(out)
# [-0.41075233 -0.201336 0.10016675 0.30452029]
""")
add_sample_code(globals()["asinh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.asinh(x)
print(out)
# [-0.39003533, -0.19869010, 0.09983408, 0.29567307]
""")
add_sample_code(globals()["acosh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1., 3., 4., 5.])
out = paddle.acosh(x)
print(out)
# [0. , 1.76274729, 2.06343699, 2.29243159]
""")
add_sample_code(globals()["atanh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.atanh(x)
print(out)
# [-0.42364895, -0.20273256, 0.10033535, 0.30951962]
""")
add_sample_code(globals()["round"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.5, -0.2, 0.6, 1.5])
out = paddle.round(x)
print(out)
# [-1. -0. 1. 2.]
""")
add_sample_code(globals()["reciprocal"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.reciprocal(x)
print(out)
# [-2.5 -5. 10. 3.33333333]
""")
add_sample_code(globals()["square"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.square(x)
print(out)
# [0.16 0.04 0.01 0.09]
""")
add_sample_code(globals()["lgamma"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.lgamma(x)
print(out)
# [1.31452441, 1.76149750, 2.25271273, 1.09579802]
""")
add_sample_code(globals()["softplus"], r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = F.softplus(x)
print(out)
# [0.513015, 0.598139, 0.744397, 0.854355]
""")
add_sample_code(globals()["softsign"], r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = F.softsign(x)
print(out)
# [-0.285714, -0.166667, 0.0909091, 0.230769]
""")
__all__ += ['erf']
_erf_ = generate_layer_fn('erf')
def erf(x, name=None):
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
if val is not None:
kwargs[name] = val
return _erf_(**kwargs)
erf.__doc__ = r"""
:strong:`Erf Operator`
For more details, see [Error function](https://en.wikipedia.org/wiki/Error_function).
Equation:
.. math::
out = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-\eta^{2}} d\eta
Args:
x (Tensor): The input tensor, it's data type should be float32, float64.
Returns:
Tensor: The output of Erf op, dtype: float32 or float64, the same as the input, shape: the same as the input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.erf(x)
print(out)
# [-0.42839236 -0.22270259 0.11246292 0.32862676]
"""