diff --git a/python/paddle/fluid/tests/unittests/test_logical_op.py b/python/paddle/fluid/tests/unittests/test_logical_op.py index 91d339940d114cd97b6c4664a4faf273408ae491..e2c7cf3a5bb2bf8c2f1be7b8b0fc273d02fc644a 100755 --- a/python/paddle/fluid/tests/unittests/test_logical_op.py +++ b/python/paddle/fluid/tests/unittests/test_logical_op.py @@ -18,8 +18,8 @@ import op_test import unittest import numpy as np import paddle -import paddle.fluid as fluid -from paddle.static import Program, program_guard +from paddle.static import Program, program_guard, Executor +from paddle.framework import _non_static_mode from paddle.fluid.framework import _test_eager_guard SUPPORTED_DTYPES = [ @@ -109,13 +109,13 @@ TEST_META_WRONG_SHAPE_DATA = { def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True): paddle.enable_static() - startup_program = fluid.Program() - main_program = fluid.Program() + startup_program = Program() + main_program = Program() place = paddle.CPUPlace() - if use_gpu and fluid.core.is_compiled_with_cuda(): + if use_gpu and paddle.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) - exe = fluid.Executor(place) - with fluid.program_guard(main_program, startup_program): + exe = Executor(place) + with program_guard(main_program, startup_program): x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype) op = getattr(paddle, op_str) feed_list = {'x': x_np} @@ -132,7 +132,7 @@ def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True): def run_dygraph(x_np, y_np, op_str, use_gpu=False, binary_op=True): place = paddle.CPUPlace() - if use_gpu and fluid.core.is_compiled_with_cuda(): + if use_gpu and paddle.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) paddle.disable_static(place) op = getattr(paddle, op_str) @@ -147,7 +147,7 @@ def run_dygraph(x_np, y_np, op_str, use_gpu=False, binary_op=True): def run_eager(x_np, y_np, op_str, use_gpu=False, binary_op=True): place = paddle.CPUPlace() - if use_gpu and fluid.core.is_compiled_with_cuda(): + if use_gpu and paddle.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) paddle.disable_static(place) with _test_eager_guard(): @@ -213,16 +213,16 @@ def test_type_error(unit_test, use_gpu, type_str_map): if binary_op: if type_str_map['x'] != type_str_map['y']: unit_test.assertRaises(error_type, op, x=x, y=y) - if not fluid._non_static_mode(): + if not _non_static_mode(): error_type = TypeError unit_test.assertRaises(error_type, op, x=x, y=y, out=1) else: - if not fluid._non_static_mode(): + if not _non_static_mode(): error_type = TypeError unit_test.assertRaises(error_type, op, x=x, out=1) place = paddle.CPUPlace() - if use_gpu and fluid.core.is_compiled_with_cuda(): + if use_gpu and paddle.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) for op_data in TEST_META_OP_DATA: meta_data = dict(op_data) diff --git a/python/paddle/tensor/array.py b/python/paddle/tensor/array.py index 49678443f1f1cb50fd766fc8d1dd1cb81a580509..856b79c2a68940abe7e61f0b67635e81dd883a47 100644 --- a/python/paddle/tensor/array.py +++ b/python/paddle/tensor/array.py @@ -14,7 +14,11 @@ # Define functions about array. -from ..fluid import layers +import paddle +from ..static import Variable +from ..framework import LayerHelper, core, _non_static_mode +from ..fluid.data_feeder import check_type +from ..fluid.data_feeder import check_variable_and_dtype __all__ = [] @@ -43,7 +47,24 @@ def array_length(array): arr_len = paddle.tensor.array_length(arr) print(arr_len) # 1 """ - return layers.array_length(array) + if _non_static_mode(): + assert isinstance( + array, + list), "The 'array' in array_write must be a list in dygraph mode" + return len(array) + + if not isinstance( + array, + Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: + raise TypeError( + "array should be tensor array vairable in array_length Op") + + helper = LayerHelper('array_length', **locals()) + tmp = helper.create_variable_for_type_inference(dtype='int64') + tmp.stop_gradient = True + helper.append_op( + type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]}) + return tmp def array_read(array, i): @@ -85,7 +106,32 @@ def array_read(array, i): item = paddle.tensor.array_read(arr, i) print(item) # [[5., 5., 5.]] """ - return layers.array_read(array, i) + if _non_static_mode(): + assert isinstance( + array, + list), "The 'array' in array_read must be list in dygraph mode" + assert isinstance( + i, Variable + ), "The index 'i' in array_read must be Variable in dygraph mode" + assert i.shape == [ + 1 + ], "The shape of index 'i' should be [1] in dygraph mode" + i = i.numpy().item(0) + return array[i] + + check_variable_and_dtype(i, 'i', ['int64'], 'array_read') + helper = LayerHelper('array_read', **locals()) + if not isinstance( + array, + Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: + raise TypeError("array should be tensor array vairable") + out = helper.create_variable_for_type_inference(dtype=array.dtype) + helper.append_op( + type='read_from_array', + inputs={'X': [array], + 'I': [i]}, + outputs={'Out': [out]}) + return out def array_write(x, i, array=None): @@ -119,7 +165,51 @@ def array_write(x, i, array=None): item = paddle.tensor.array_read(arr, i) print(item) # [[5., 5., 5.]] """ - return layers.array_write(x, i, array) + if _non_static_mode(): + assert isinstance( + x, Variable + ), "The input data 'x' in array_write must be Variable in dygraph mode" + assert isinstance( + i, Variable + ), "The index 'i' in array_write must be Variable in dygraph mode" + assert i.shape == [ + 1 + ], "The shape of index 'i' should be [1] in dygraph mode" + i = i.numpy().item(0) + if array is None: + array = create_array(x.dtype) + assert isinstance( + array, + list), "The 'array' in array_write must be a list in dygraph mode" + assert i <= len( + array + ), "The index 'i' should not be greater than the length of 'array' in dygraph mode" + if i < len(array): + array[i] = x + else: + array.append(x) + return array + + check_variable_and_dtype(i, 'i', ['int64'], 'array_write') + check_type(x, 'x', (Variable), 'array_write') + helper = LayerHelper('array_write', **locals()) + if array is not None: + if not isinstance( + array, + Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: + raise TypeError( + "array should be tensor array vairable in array_write Op") + if array is None: + array = helper.create_variable( + name="{0}.out".format(helper.name), + type=core.VarDesc.VarType.LOD_TENSOR_ARRAY, + dtype=x.dtype) + helper.append_op( + type='write_to_array', + inputs={'X': [x], + 'I': [i]}, + outputs={'Out': [array]}) + return array def create_array(dtype, initialized_list=None): @@ -151,4 +241,31 @@ def create_array(dtype, initialized_list=None): print(item) # [[5., 5., 5.]] """ - return layers.create_array(dtype, initialized_list) + array = [] + if initialized_list is not None: + if not isinstance(initialized_list, (list, tuple)): + raise TypeError( + "Require type(initialized_list) should be list/tuple, but received {}". + format(type(initialized_list))) + array = list(initialized_list) + + # NOTE: Only support plain list like [x, y,...], not support nested list in static mode. + for val in array: + if not isinstance(val, Variable): + raise TypeError( + "All values in `initialized_list` should be Variable, but recevied {}.". + format(type(val))) + + if _non_static_mode(): + return array + + helper = LayerHelper("array", **locals()) + tensor_array = helper.create_variable( + name="{0}.out".format(helper.name), + type=core.VarDesc.VarType.LOD_TENSOR_ARRAY, + dtype=dtype) + + for val in array: + array_write(x=val, i=array_length(tensor_array), array=tensor_array) + + return tensor_array diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 509ae903f59e48196131e16b0397483bdd7015a7..2fcf9ff4213d4ec2b176fbca9922d6403de80204 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -14,11 +14,12 @@ import numpy as np from ..fluid.layer_helper import LayerHelper -from ..framework import _varbase_creator, _dygraph_tracer +from ..framework import _varbase_creator, _dygraph_tracer, in_dygraph_mode, _non_static_mode from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype from ..static import Variable -from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode -from ..fluid.layers import transpose, cast # noqa: F401 +from ..fluid.framework import _in_legacy_dygraph +from .manipulation import cast + from ..fluid import layers import paddle from paddle.common_ops_import import core @@ -31,6 +32,95 @@ __all__ = [] K_DEFAULT_DIM = 9 +def transpose(x, perm, name=None): + """ + Permute the data dimensions of `input` according to `perm`. + + The `i`-th dimension of the returned tensor will correspond to the + perm[i]-th dimension of `input`. + + Args: + x (Tensor): The input Tensor. It is a N-D Tensor of data types bool, float32, float64, int32. + perm (list|tuple): Permute the input according to the data of perm. + name (str): The name of this layer. It is optional. + + Returns: + Tensor: A transposed n-D Tensor, with data type being bool, float32, float64, int32, int64. + + For Example: + + .. code-block:: text + + x = [[[ 1 2 3 4] [ 5 6 7 8] [ 9 10 11 12]] + [[13 14 15 16] [17 18 19 20] [21 22 23 24]]] + shape(x) = [2,3,4] + + # Example 1 + perm0 = [1,0,2] + y_perm0 = [[[ 1 2 3 4] [13 14 15 16]] + [[ 5 6 7 8] [17 18 19 20]] + [[ 9 10 11 12] [21 22 23 24]]] + shape(y_perm0) = [3,2,4] + + # Example 2 + perm1 = [2,1,0] + y_perm1 = [[[ 1 13] [ 5 17] [ 9 21]] + [[ 2 14] [ 6 18] [10 22]] + [[ 3 15] [ 7 19] [11 23]] + [[ 4 16] [ 8 20] [12 24]]] + shape(y_perm1) = [4,3,2] + + Examples: + + .. code-block:: python + + import paddle + + x = paddle.randn([2, 3, 4]) + x_transposed = paddle.transpose(x, perm=[1, 0, 2]) + print(x_transposed.shape) + # [3L, 2L, 4L] + + """ + if in_dygraph_mode(): + return _C_ops.final_state_transpose(x, perm) + else: + if _in_legacy_dygraph(): + out, _ = _C_ops.transpose2(x, 'axis', perm) + return out + + check_variable_and_dtype(x, 'x', [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64', + 'complex128' + ], 'transpose') + check_type(perm, 'perm', (list, tuple), 'transpose') + if isinstance(perm, tuple): + perm = list(perm) + if len(perm) != len(x.shape): + raise ValueError( + "Input(perm) is the permutation of dimensions of Input(x), " + "its length should be equal to dimensions of Input(x), " + "but received dimension of Input(x) is %s, " + "the length of Input(perm) is %s." % (len(x.shape), len(perm))) + for idx, dim in enumerate(perm): + if dim >= len(x.shape): + raise ValueError( + "Each element in Input(perm) should be less than Input(x)'s dimension, " + "but %d-th element in Input(perm) is %d which exceeds Input(x)'s " + "dimension %d." % (idx, perm[idx], len(x.shape))) + + helper = LayerHelper('transpose', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + x_shape = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='transpose2', + inputs={'X': [x]}, + outputs={'Out': [out], + 'XShape': [x_shape]}, + attrs={'axis': perm}) + return out + + def matmul(x, y, transpose_x=False, transpose_y=False, name=None): """ Applies matrix multiplication to two tensors. `matmul` follows diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 636b2ef17c6a0eeb166cdaf338ce2405ee5619f1..6a18e1201785a451c0dbd746928ab48cac62c1d1 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -12,29 +12,267 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..fluid.layer_helper import LayerHelper +import paddle from ..fluid.data_feeder import check_type, check_variable_and_dtype from .layer_function_generator import templatedoc from ..static import Variable -from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode # TODO: define logic functions of a tensor -import paddle.fluid as fluid -if fluid.framework._in_eager_mode_: - Tensor = fluid.framework.core.eager.Tensor +from ..fluid.framework import _in_eager_mode_ +if _in_eager_mode_: + Tensor = paddle.fluid.framework.core.eager.Tensor else: from ..framework import VarBase as Tensor -from ..fluid.layers import is_empty # noqa: F401 -from ..fluid.layers import logical_and # noqa: F401 -from ..fluid.layers import logical_not # noqa: F401 -from ..fluid.layers import logical_or # noqa: F401 -from ..fluid.layers import logical_xor # noqa: F401 -import paddle + +from ..framework import in_dygraph_mode, _non_static_mode +from ..framework import LayerHelper +from ..fluid.framework import _in_legacy_dygraph +# TODO: define logic functions of a tensor from paddle import _C_ops from paddle.tensor.creation import full __all__ = [] +def _logical_op(op_name, x, y, out=None, name=None, binary_op=True): + if _non_static_mode(): + op = getattr(_C_ops, op_name) + if binary_op: + return op(x, y) + else: + return op(x) + check_variable_and_dtype(x, "x", [ + "bool", "int8", "int16", "int32", "int64", "float32", "float64" + ], op_name) + if y is not None: + check_variable_and_dtype(y, "y", [ + "bool", "int8", "int16", "int32", "int64", "float32", "float64" + ], op_name) + if out is not None: + check_type(out, "out", Variable, op_name) + + helper = LayerHelper(op_name, **locals()) + + if binary_op and x.dtype != y.dtype: + raise ValueError( + "(InvalidArgument) The DataType of %s Op's Variable must be consistent, but received %s and %s." + % (op_name, x.dtype, y.dtype)) + + if out is None: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + if binary_op: + helper.append_op( + type=op_name, inputs={"X": x, + "Y": y}, outputs={"Out": out}) + else: + helper.append_op(type=op_name, inputs={"X": x}, outputs={"Out": out}) + + return out + + +def logical_and(x, y, out=None, name=None): + r""" + + ``logical_and`` operator computes element-wise logical AND on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``. + Each element of ``out`` is calculated by + + .. math:: + + out = x \&\& y + + .. note:: + ``paddle.logical_and`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. + + Args: + x (Tensor): the input tensor, it's data type should be one of bool, int8, int16, in32, in64, float32, float64. + y (Tensor): the input tensor, it's data type should be one of bool, int8, int16, in32, in64, float32, float64. + out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. A location into which the result is stored. It's dimension equals with ``x``. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([True]) + y = paddle.to_tensor([True, False, True, False]) + res = paddle.logical_and(x, y) + print(res) # [True False True False] + """ + if in_dygraph_mode(): + return _C_ops.final_state_logical_and(x, y) + + return _logical_op( + op_name="logical_and", x=x, y=y, name=name, out=out, binary_op=True) + + +def logical_or(x, y, out=None, name=None): + """ + + ``logical_or`` operator computes element-wise logical OR on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``. + Each element of ``out`` is calculated by + + .. math:: + + out = x || y + + .. note:: + ``paddle.logical_or`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. + + Args: + x (Tensor): the input tensor, it's data type should be one of bool, int8, int16, in32, in64, float32, float64. + y (Tensor): the input tensor, it's data type should be one of bool, int8, int16, in32, in64, float32, float64. + out(Tensor): The ``Variable`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. A location into which the result is stored. It's dimension equals with ``x``. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + x_data = np.array([True, False], dtype=np.bool).reshape(2, 1) + y_data = np.array([True, False, True, False], dtype=np.bool).reshape(2, 2) + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + res = paddle.logical_or(x, y) + print(res) # [[ True True] [ True False]] + """ + if in_dygraph_mode(): + return _C_ops.final_state_logical_or(x, y) + return _logical_op( + op_name="logical_or", x=x, y=y, name=name, out=out, binary_op=True) + + +def logical_xor(x, y, out=None, name=None): + r""" + + ``logical_xor`` operator computes element-wise logical XOR on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``. + Each element of ``out`` is calculated by + + .. math:: + + out = (x || y) \&\& !(x \&\& y) + + .. note:: + ``paddle.logical_xor`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. + + Args: + x (Tensor): the input tensor, it's data type should be one of bool, int8, int16, in32, in64, float32, float64. + y (Tensor): the input tensor, it's data type should be one of bool, int8, int16, in32, in64, float32, float64. + out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. A location into which the result is stored. It's dimension equals with ``x``. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + x_data = np.array([True, False], dtype=np.bool).reshape([2, 1]) + y_data = np.array([True, False, True, False], dtype=np.bool).reshape([2, 2]) + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + res = paddle.logical_xor(x, y) + print(res) # [[False, True], [ True, False]] + """ + if in_dygraph_mode(): + return _C_ops.final_state_logical_xor(x, y) + + return _logical_op( + op_name="logical_xor", x=x, y=y, name=name, out=out, binary_op=True) + + +@templatedoc() +def logical_not(x, out=None, name=None): + """ + + ``logical_not`` operator computes element-wise logical NOT on ``x``, and returns ``out``. ``out`` is N-dim boolean ``Variable``. + Each element of ``out`` is calculated by + + .. math:: + + out = !x + + Args: + x(Tensor): Operand of logical_not operator. Must be a Tensor of type bool, int8, int16, in32, in64, float32, or float64. + out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor` will be created to save the output. + name(str|None): The default value is None. Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: ${out_comment} + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([True, False, True, False]) + res = paddle.logical_not(x) + print(res) # [False True False True] + """ + if in_dygraph_mode(): + return _C_ops.final_state_logical_not(x) + return _logical_op( + op_name="logical_not", x=x, y=None, name=name, out=out, binary_op=False) + + +def is_empty(x, name=None): + """ + + Test whether a Tensor is empty. + + Args: + x (Tensor): The Tensor to be tested. + name (str, optional): The default value is ``None`` . Normally users + don't have to set this parameter. For more information, + please refer to :ref:`api_guide_Name` . + + Returns: + Tensor: A bool scalar Tensor. True if 'x' is an empty Tensor. + + Examples: + .. code-block:: python + + import paddle + + input = paddle.rand(shape=[4, 32, 32], dtype='float32') + res = paddle.is_empty(x=input) + print("res:", res) + # ('res:', Tensor: eager_tmp_1 + # - place: CPUPlace + # - shape: [1] + # - layout: NCHW + # - dtype: bool + # - data: [0]) + + """ + if in_dygraph_mode(): + return _C_ops.final_state_is_empty(x) + if _in_legacy_dygraph(): + return _C_ops.is_empty(x) + + check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], + 'is_empty') + check_type(name, "name", (str, type(None)), "is_empty") + + helper = LayerHelper("is_empty", **locals()) + cond = helper.create_variable_for_type_inference(dtype='bool') + cond.stop_gradient = True + helper.append_op( + type='is_empty', inputs={'X': [x]}, outputs={'Out': [cond]}) + return cond + + def equal_all(x, y, name=None): """ This OP returns the truth value of :math:`x == y`. True if two inputs have the same elements, False otherwise. diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 5c290aa0eb760ae5d7b46c4e9927205fa1a0a307..b2fb9d6c37ff22af702d9b26192792b8024864b6 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -17,8 +17,8 @@ import paddle from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype from ..fluid import layers -from ..framework import core -from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode +from ..framework import core, in_dygraph_mode, _non_static_mode +from ..fluid.framework import _in_legacy_dygraph from paddle.common_ops_import import convert_np_dtype_to_dtype_ from paddle.common_ops_import import Variable from paddle.common_ops_import import VarDesc @@ -401,7 +401,15 @@ def nonzero(x, as_tuple=False): if paddle.in_dynamic_mode(): outs = _C_ops.where_index(x) else: - outs = layers.where(x) + helper = LayerHelper("where_index", **locals()) + + outs = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.INT64) + + helper.append_op( + type='where_index', + inputs={'Condition': x}, + outputs={'Out': [outs]}) if not as_tuple: return outs @@ -592,10 +600,10 @@ def where(condition, x=None, y=None, name=None): # [3]]),) """ if np.isscalar(x): - x = layers.fill_constant([1], np.array([x]).dtype.name, x) + x = paddle.full([1], x, np.array([x]).dtype.name) if np.isscalar(y): - y = layers.fill_constant([1], np.array([y]).dtype.name, y) + y = paddle.full([1], y, np.array([y]).dtype.name) if x is None and y is None: return nonzero(condition, as_tuple=True)