#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import warnings

from ..layer_helper import LayerHelper
from ..framework import (
    _current_expected_place,
    convert_np_dtype_to_dtype_,
    _varbase_creator,
    in_dygraph_mode,
)
from ..framework import Variable
from ..core import VarDesc
from .. import core
from .layer_function_generator import templatedoc
from . import utils
from ..data_feeder import (
    check_variable_and_dtype,
    check_type,
    check_dtype,
    convert_dtype,
)
from paddle.utils import deprecated

from .utils import check_shape
from paddle import _C_ops, _legacy_C_ops

__all__ = [
    'cast',
    'concat',
    'sums',
    'assign',
    'fill_constant_batch_size_like',
    'fill_constant',
    'argmin',
    'argmax',
    'zeros',
]


def cast(x, dtype):
    """

    This OP takes in the Tensor :attr:`x` with :attr:`x.dtype` and casts it
    to the output with :attr:`dtype`. Casting to the same dtype as the input
    is a no-op, but it is allowed.

    Args:
        x(Tensor): An input N-D Tensor with data type bool, float16,
            float32, float64, int32, int64, uint8.
        dtype(np.dtype|str): Data type of the output:
            bool, float16, float32, float64, int8, int32, int64, uint8.

    Returns:
        Tensor: A Tensor with the same shape as the input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([2, 3, 4], 'float64')
            y = paddle.cast(x, 'uint8')
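            # y is a uint8 Tensor holding [2, 3, 4]
            print(y)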
    """
    if in_dygraph_mode():
        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)
        return _C_ops.cast(x, dtype)
    else:
        check_variable_and_dtype(
            x,
            'x',
            [
                'bool',
                'float16',
                'float32',
                'float64',
                'int16',
                'int32',
                'int64',
                'uint8',
                'uint16',
            ],
            'cast',
        )
        check_dtype(
            dtype,
            'dtype',
            [
                'bool',
                'float16',
                'float32',
                'float64',
                'int8',
                'int16',
                'int32',
                'int64',
                'uint8',
                'uint16',
            ],
            'cast',
        )
        helper = LayerHelper('cast', **locals())
        out = helper.create_variable_for_type_inference(
            dtype=dtype, stop_gradient=x.stop_gradient
        )
        helper.append_op(
            type='cast',
            inputs={'X': [x]},
            outputs={'Out': [out]},
            attrs={'in_dtype': x.dtype, 'out_dtype': out.dtype},
        )
        return out


def concat(input, axis=0, name=None):
    """
    This OP concatenates the input along the axis.

    Args:
        input(list|tuple|Tensor): ``input`` can be a Tensor, or a Tensor list or tuple, with data type
            bool, float16, float32, float64, int32, int64. All the Tensors in ``input`` must have the same data type.
        axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
            It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64.
            The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way
            as ``axis+R``. Default is 0.
        name (str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor with the same data type as ``input``.

    Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            import numpy as np

            in1 = np.array([[1, 2, 3],
                            [4, 5, 6]])
            in2 = np.array([[11, 12, 13],
                            [14, 15, 16]])
            in3 = np.array([[21, 22],
                            [23, 24]])
            with fluid.dygraph.guard():
                x1 = fluid.dygraph.to_variable(in1)
                x2 = fluid.dygraph.to_variable(in2)
                x3 = fluid.dygraph.to_variable(in3)
                # When the axis is negative, the real axis is (axis + Rank(x)).
                # As follows, axis is -1, Rank(x) is 2, the real axis is 1
                out1 = fluid.layers.concat(input=[x1, x2, x3], axis=-1)
                out2 = fluid.layers.concat(input=[x1, x2], axis=0)
                print(out1.numpy())
                # [[ 1  2  3 11 12 13 21 22]
                #  [ 4  5  6 14 15 16 23 24]]
                print(out2.numpy())
                # [[ 1  2  3]
                #  [ 4  5  6]
                #  [11 12 13]
                #  [14 15 16]]
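
                # `axis` may also be passed as an int32 Tensor with shape [1]
                # (a sketch, not part of the original example):
                # axis_t = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
                # out3 = fluid.layers.concat(input=[x1, x2], axis=axis_t)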
    """
    if in_dygraph_mode():
        if isinstance(axis, Variable):
            axis = axis.numpy()
            axis = axis.item(0)
        if not isinstance(input, Variable):
            input = [t for t in input if t.shape.count(0) == 0]
        out = _C_ops.concat(input, axis)
        return out
    else:
        check_type(input, 'input', (list, tuple, Variable), 'concat')
        if not isinstance(input, Variable):
            for id, x in enumerate(input):
                check_variable_and_dtype(
                    x,
                    'input[' + str(id) + ']',
                    ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
                    'concat',
                )
                if x.dtype != input[0].dtype:
                    raise TypeError(
                        "All the Tensors in the input must have the same data type."
                    )
        else:
            input = [input]
        check_type(axis, 'axis', (int, Variable), 'concat')
        if isinstance(axis, Variable):
            check_dtype(
                axis.dtype,
                'axis',
                ['int32', 'int64'],
                'concat',
                "The data type of axis must be int32 or int64 when axis is a Tensor",
            )

        helper = LayerHelper('concat', **locals())
        out = helper.create_variable_for_type_inference(
            dtype=helper.input_dtype()
        )

        if input[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
            # NOTE(liym27): Don't remove this if branch!
            # This feature is supported for Dynamic-to-Static, because after transformed, the type of inputs[0]
            # is LOD_TENSOR_ARRAY in some scenarios. And this feature can be used in static mode.
            assert len(input) == 1, (
                "If the elements of 'input' in concat are Variable(LoDTensorArray), "
                "number of the elements must be 1, but received %s."
                % len(input)
            )
            out_index = helper.create_variable_for_type_inference(dtype="int32")
            helper.append_op(
                type='tensor_array_to_tensor',
                inputs={'X': input[0]},
                outputs={'Out': [out], 'OutIndex': [out_index]},
                attrs={'axis': axis, 'use_stack': False},
            )
        else:
            inputs = {'X': input}
            attrs = {}
            if isinstance(axis, Variable):
                axis.stop_gradient = True
            attrs['axis'] = axis
            helper.append_op(
                type='concat',
                inputs=inputs,
                outputs={'Out': [out]},
                attrs=attrs,
            )
        return out


def sums(input, out=None):
    r"""
    This function computes the element-wise sum of multiple input Tensors.

    - Case 1, sum of 3 Tensors

    .. code-block:: text

        # Input Tensors
        x0.shape = [2, 3]
        x0.data = [[1., 2., 3.],
                   [4., 5., 6.]]
        x1.shape = [2, 3]
        x1.data = [[10., 20., 30.],
                   [40., 50., 60.]]
        x2.shape = [2, 3]
        x2.data = [[100., 200., 300.],
                   [400., 500., 600.]]

        # Output Tensor
        out.shape = [2, 3]
        out.data = [[111., 222., 333.],
                    [444., 555., 666.]]

    Args:
        input (list): A list of Variables which hold input Tensors with the same
            data type and shape. Optional data types are: float32, float64, int32, int64.
        out (Variable, optional): Output Tensor. It can be any existing Variable.
            The default value is None, then a new Variable will be created and returned.

    Returns:
        Variable: The sum of inputs. The shape and data type is the same with input. \
            If :code:`out` is not None, the returned value is :code:`out` .

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            x0 = fluid.layers.fill_constant(shape=[16, 32], dtype='int64', value=1)
            x1 = fluid.layers.fill_constant(shape=[16, 32], dtype='int64', value=2)
            x2 = fluid.layers.fill_constant(shape=[16, 32], dtype='int64', value=3)
            x3 = fluid.layers.fill_constant(shape=[16, 32], dtype='int64', value=0)

            # Sum of multiple Tensors, the result is stored to a new Variable sum0 (sum0=x0+x1+x2, the value is [[6, ..., 6], ..., [6, ..., 6]])
            sum0 = fluid.layers.sums(input=[x0, x1, x2])
            # Sum of multiple Tensors, sum1 and x3 represents the same Variable (x3=x0+x1+x2, the value is [[6, ..., 6], ..., [6, ..., 6]])
            sum1 = fluid.layers.sums(input=[x0, x1, x2], out=x3)
    """
    check_type(input, 'input', (Variable, tuple, list), 'sums')
    if isinstance(input, (list, tuple)):
        for input_section in input:
            check_variable_and_dtype(
                input_section,
                "input",
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'sums',
            )
    else:
        check_variable_and_dtype(
            input,
            "input",
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            'sums',
        )
    helper = LayerHelper('sum', **locals())
    if out is None:
        out = helper.create_variable_for_type_inference(
            dtype=helper.input_dtype()
        )
    else:
        check_variable_and_dtype(
            out, "out", ['float32', 'float64', 'int32', 'int64'], 'sums'
        )

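    # A single `sum` op adds all the inputs element-wise into `out`.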
    helper.append_op(
        type='sum',
        inputs={'X': input},
        outputs={'Out': out},
        attrs={'use_mkldnn': False},
    )
    return out


def assign(input, output=None):
    """

    The OP copies the :attr:`input` to the :attr:`output`.

    Parameters:
        input (Tensor|numpy.ndarray|list|tuple|scalar): A tensor, numpy ndarray, tuple/list of scalar,
            or scalar. Its data type supports float16, float32, float64, int32, int64, and bool.
            Note: the float64 data will be converted to float32 because of current platform protobuf
            data limitation.
        output (Tensor, optional): A tensor. If :attr:`output` is None, a new tensor will
            be created as :attr:`output`. Default: None.

    Returns:
        Tensor: A tensor with the same shape, data type and value as :attr:`input`.

    Examples:
        .. code-block:: python
          import paddle
          import numpy as np
          data = paddle.full(shape=[3, 2], fill_value=2.5, dtype='float64') # [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
          array = np.array([[1, 1],
                            [3, 4],
                            [1, 3]]).astype(np.int64)
          result1 = paddle.zeros(shape=[3, 3], dtype='float32')
          paddle.assign(array, result1) # result1 = [[1, 1], [3, 4], [1, 3]]
          result2 = paddle.assign(data)  # result2 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
          result3 = paddle.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
    """
    helper = LayerHelper('assign', **locals())
    check_type(
        input,
        'input',
        (Variable, numpy.ndarray, list, tuple, float, int, bool),
        'assign',
    )
    is_inplace = output is not None

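    # Normalize python scalars and lists/tuples to numpy arrays so that the
    # ndarray branch below can handle them uniformly.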
    if numpy.isscalar(input) and not isinstance(input, str):
        input = numpy.array([input])
    elif isinstance(input, (list, tuple)):
        input = numpy.array(input)
    # NOTE(Aurelius84): Why we judge core.VarBase?
    # In case of @to_static, a VarBase can be as input of `assign`,
    # but in_dygraph_mode()==False under @to_static, which means
    # isinstance(VarBase, Variable) == False. It will cause return None
    # after this api.
    if isinstance(input, (Variable, core.VarBase)):
        if in_dygraph_mode():
            if output is None:
                output = _C_ops.assign(input)
            else:
                _C_ops.assign_out_(input, output)
        else:
            check_dtype(
                input.dtype,
                'input',
                [
                    'float16',
                    'uint16',
                    'float32',
                    'float64',
                    'int32',
                    'int64',
                    'uint8',
                    'bool',
                ],
                'assign',
                '(When the type of input in assign is Variable.)',
            )
            if output is None:
                output = helper.create_variable_for_type_inference(
                    dtype=input.dtype
                )
            helper.append_op(
                type='assign', inputs={'X': [input]}, outputs={'Out': [output]}
            )
    elif isinstance(input, numpy.ndarray):
        # Not support [var, var, ...] currently.
        if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input):
            raise TypeError(
                "Required type(input) numpy.ndarray, but found `list(Variable)` in input."
            )
        dtype = convert_np_dtype_to_dtype_(input.dtype)
        if dtype == VarDesc.VarType.FP64:
            # Setting FP64 numpy data is not supported in Paddle, so we
            # use FP32 here
            warnings.warn(
                "paddle.assign doesn't support float64 input now due "
                "to current platform protobuf data limitation, we convert "
                "it to float32"
            )
            dtype = VarDesc.VarType.FP32
        if dtype == VarDesc.VarType.BOOL:
            value_name = "bool_values"
            values = [int(v) for v in input.flat]
        elif dtype == VarDesc.VarType.FP32:
            value_name = "fp32_values"
            values = [float(v) for v in input.flat]
        elif dtype == VarDesc.VarType.INT32:
            value_name = "int32_values"
            values = [int(v) for v in input.flat]
        elif dtype == VarDesc.VarType.INT64:
            value_name = "int64_values"
            values = [int(v) for v in input.flat]
        else:
            raise TypeError(
                "When the type of 'input' in assign is numpy.ndarray, "
                "the data type of 'input' must be bool, float32, int32 or int64, but "
                "received %s." % convert_dtype(dtype)
            )
        if input.size > 1024 * 1024:
            raise ValueError(
                "The size of input is too big. Please consider "
                "saving it to file and 'load_op' to load it"
            )
        if in_dygraph_mode():
            if output is None:
                output = zeros(list(input.shape), dtype)
            _C_ops.assign_value_(
                output,
                list(input.shape),
                dtype,
                values,
                _current_expected_place(),
            )
        else:
            if output is None:
                output = helper.create_variable_for_type_inference(
                    dtype=input.dtype
                )
            helper.append_op(
                type='assign_value',
                outputs={'Out': [output]},
                attrs={
                    'dtype': dtype,
                    'shape': list(input.shape),
                    value_name: values,
                },
            )
    if is_inplace and in_dygraph_mode():
        output._bump_inplace_version()
    return output


def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
    """

    This OP creates a Tensor with specified `shape` and `dtype`, and
    initializes it with a constant specified by `value`.

    The attribute `stop_gradient` of the created Tensor is set to True.

    Args:
        shape(list|tuple|Tensor): Shape of the output Tensor, the data type of ``shape`` is int32 or int64.
            If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
            If ``shape`` is a Tensor, it should be a 1-D Tensor with data type int32 or int64.
        dtype(np.dtype|str): Data type of the output Tensor which can
            be float16, float32, float64, uint8, int16, int32, int64.
        value(bool|float|int|Tensor): The constant value used to initialize
            the Tensor to be created. If ``value`` is a Tensor, it should be a 1-D Tensor.
        force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
        out(Tensor, optional): Optional output which can be any created
            Tensor that meets the requirements to store the result of operation.
            if ``out`` is None, a new Tensor will be created to store the result.
        name(str, optional): The default value is None.  Normally there is no need for user to set this
            property.  For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: Tensor which is created according to shape and dtype.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          # attr shape is a list which doesn't contain Tensor.
          data1 = fluid.layers.fill_constant(shape=[2,1], value=0, dtype='int64') # data1=[[0],[0]]
          data2 = fluid.layers.fill_constant(shape=[2,1], value=5, dtype='int64', out=data1)
          # data1=[[5], [5]] data2=[[5], [5]]

          # attr shape is a list which contains Tensor.
          positive_2 = fluid.layers.fill_constant([1], "int32", 2)
          data3 = fluid.layers.fill_constant(shape=[1, positive_2], dtype='float32', value=1.5) # data3=[[1.5, 1.5]]

          # attr shape is a Tensor.
          shape = fluid.layers.fill_constant([2], "int32", 2) # shape=[2,2]
          data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]]

          # attr value is a Tensor.
          val = fluid.layers.fill_constant([1], "float32", 2.0) # val=[2.0]
          data5 = fluid.layers.fill_constant(shape=[2,1], value=val, dtype='float32') #data5=[[2.0],[2.0]]
    """

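    # Dynamic graph mode calls the C++ `full` kernel directly and returns an
    # eager Tensor; the else branch builds a `fill_constant` op for the
    # static program instead.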
    if in_dygraph_mode():
        place = _current_expected_place()
        if force_cpu:
            place = core.CPUPlace()
        if isinstance(shape, (list, tuple)):
            shape = utils.convert_shape_to_list(shape)

        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)

        if out is None:
            out = _C_ops.full(shape, float(value), dtype, place)
            out.stop_gradient = True
            return out

        if out is not None:
            # Final state mode supports a non-None `out`.
            _C_ops.full_(out, shape, float(value), dtype, place)
            out.stop_gradient = True
            return out
    else:
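        # Encode `value` into the op's attribute map: integer dtypes store an
        # exact int, other dtypes a float; `str_value` additionally carries
        # the value as a string.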
        attrs = {'force_cpu': force_cpu}
        dtype = convert_dtype(dtype)
        if not isinstance(value, Variable):
            if dtype in ['uint8', 'int16', 'int32', 'int64']:
                attrs['str_value'] = str(int(value))
                attrs['value'] = int(value)
            else:
                attrs['str_value'] = str(float(value))
                attrs['value'] = float(value)

        helper = LayerHelper("fill_constant", **locals())
        inputs = {}
        if isinstance(value, Variable):
            if convert_dtype(value.dtype) != dtype:
                value = cast(value, dtype)
            inputs['ValueTensor'] = value

        check_shape(shape)
        check_dtype(
            dtype,
            'dtype',
            [
                'bool',
                'float16',
                'float32',
                'float64',
                'uint8',
                'int16',
                'int32',
                'int64',
                'complex64',
                'complex128',
            ],
            'fill_constant',
        )
        check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
        if out is not None:
            check_variable_and_dtype(
                out, 'out', [convert_dtype(dtype)], 'fill_constant'
            )
        helper = LayerHelper("fill_constant", **locals())
        utils.get_shape_tensor_inputs(
            inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant'
        )
        if out is None:
            out = helper.create_variable_for_type_inference(dtype=dtype)
        attrs['dtype'] = out.dtype
        helper.append_op(
            type='fill_constant',
            inputs=inputs,
            outputs={'Out': [out]},
            attrs=attrs,
            stop_gradient=True,
        )
        out.stop_gradient = True
        return out


@deprecated(since='1.8.0', update_to="paddle.fluid.layers.fill_constant")
@templatedoc()
def fill_constant_batch_size_like(
    input,
    shape,
    dtype,
    value,
    input_dim_idx=0,
    output_dim_idx=0,
    force_cpu=False,
):
    """
    This OP creates a Tensor according to the shape and dtype, and initializes the
    Tensor with the constants provided in ``value``. When the input is a LoDTensor
    and input_dim_idx is 0, the output_dim_idx dimension is set to the batch size
    taken from the input. The stop_gradient attribute of the created Tensor is
    False by default.

    Args:
        input(Variable): Tensor which data type is float32, float64, int32 and int64.
        shape(list): The shape of Tensor to be created, Tensor's shape may be changed
            according to the input.
        dtype(np.dtype|core.VarDesc.VarType|str): The data type of created Tensor which
            can be float32, float64, int32, int64.
        value(float|int): The constant value used to initialize the Tensor to be created.
        input_dim_idx(int): When the value is 0 and the input is LoDTensor, the output_dim_idx
            dimension of the created Tensor is set to the batch_size value of input.
            The default value is 0.
        output_dim_idx(int): Used to specify which dimension of Tensor is created to be set
            the value of batch_size of input Tensor. The default value is 0.
        force_cpu(bool): data should be on CPU if it's true, default value is False.

    Returns:
        Variable: Tensor which will be created according to dtype.

    Examples:

        .. code-block:: python

             import paddle.fluid as fluid
             like = fluid.layers.fill_constant(shape=[1,2], value=10, dtype='int64') #like=[[10, 10]]
             data = fluid.layers.fill_constant_batch_size_like(
                    input=like, shape=[1], value=0, dtype='int64') #like=[[10, 10]] data=[0]

    """
    if in_dygraph_mode():
        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)

        place = _current_expected_place()
        if force_cpu:
            place = core.CPUPlace()
        out = _C_ops.full_batch_size_like(
            input, shape, dtype, value, input_dim_idx, output_dim_idx, place
        )
        out.stop_gradient = True
        return out
    else:
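        # Static graph mode: emit a `fill_constant_batch_size_like` op whose
        # output takes its batch dimension from `input` at runtime.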
        helper = LayerHelper("fill_constant_batch_size_like", **locals())
        out = helper.create_variable_for_type_inference(dtype=dtype)
        attrs = {
            'shape': shape,
            'dtype': out.dtype,
            'value': float(value),
            'input_dim_idx': input_dim_idx,
            'output_dim_idx': output_dim_idx,
            'force_cpu': force_cpu,
        }
        if convert_dtype(dtype) in ['int64', 'int32']:
            attrs['str_value'] = str(int(value))
        else:
            attrs['str_value'] = str(float(value))
        helper.append_op(
            type='fill_constant_batch_size_like',
            inputs={'Input': input},
            outputs={'Out': [out]},
            attrs=attrs,
        )
        out.stop_gradient = True
        return out


def argmin(x, axis=0):
    """
        :alias_main: paddle.argmin
        :alias: paddle.argmin,paddle.tensor.argmin,paddle.tensor.search.argmin
        :old_api: paddle.fluid.layers.argmin

    **argmin**

    This OP computes the indices of the minimum elements of the input tensor
    along the provided axis.

    Args:
        x(Variable): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is Rank(x). When axis<0, it works the same way
            as axis+R. Default is 0.
    Returns:
        Variable: A Tensor with data type int64.
    Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            import numpy as np

            in1 = np.array([[[5,8,9,5],
                            [0,0,1,7],
                            [6,9,2,4]],
                            [[5,2,4,2],
                            [4,7,7,9],
                            [1,7,0,6]]])
            with fluid.dygraph.guard():
                x = fluid.dygraph.to_variable(in1)
                out1 = fluid.layers.argmin(x=x, axis=-1)
                out2 = fluid.layers.argmin(x=x, axis=0)
                out3 = fluid.layers.argmin(x=x, axis=1)
                out4 = fluid.layers.argmin(x=x, axis=2)
                print(out1.numpy())
                # [[0 0 2]
                #  [1 0 2]]
                print(out2.numpy())
                # [[0 1 1 1]
                #  [0 0 0 0]
                #  [1 1 1 0]]
                print(out3.numpy())
                # [[1 1 1 2]
                #  [2 0 2 0]]
                print(out4.numpy())
                # [[0 0 2]
                #  [1 0 2]]
    """
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'],
        'argmin',
    )
    helper = LayerHelper("arg_min", **locals())
    out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
    helper.append_op(
        type='arg_min',
        inputs={'X': x},
        outputs={'Out': [out]},
        attrs={'axis': axis},
    )
    out.stop_gradient = True
    return out


def argmax(x, axis=0):
    """
    **argmax**

    This OP computes the indices of the maximum elements of the input tensor
    along the provided axis.

    Args:
        x(Variable): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is Rank(x). When axis<0, it works the same way
            as axis+R. Default is 0.
    Returns:
        Variable: A Tensor with data type int64.
    Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            import numpy as np

            in1 = np.array([[[5,8,9,5],
                            [0,0,1,7],
                            [6,9,2,4]],
                            [[5,2,4,2],
                            [4,7,7,9],
                            [1,7,0,6]]])
            with fluid.dygraph.guard():
                x = fluid.dygraph.to_variable(in1)
                out1 = fluid.layers.argmax(x=x, axis=-1)
                out2 = fluid.layers.argmax(x=x, axis=0)
                out3 = fluid.layers.argmax(x=x, axis=1)
                out4 = fluid.layers.argmax(x=x, axis=2)
                print(out1.numpy())
                # [[2 3 1]
                #  [0 3 1]]
                print(out2.numpy())
                # [[0 0 0 0]
                #  [1 1 1 1]
                #  [0 0 0 1]]
                print(out3.numpy())
                # [[2 2 0 1]
                #  [0 1 1 1]]
                print(out4.numpy())
                # [[2 3 1]
                #  [0 3 1]]
    """
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'],
        'argmax',
    )
    helper = LayerHelper("arg_max", **locals())
    out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
    helper.append_op(
        type='arg_max',
        inputs={'X': x},
        outputs={'Out': [out]},
        attrs={'axis': axis},
    )
    out.stop_gradient = True
    return out


def zeros(shape, dtype, force_cpu=False, name=None):
    """
    The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0.
    Its :attr:`stop_gradient` will be set to True to stop gradient computation.
    Parameters:
        shape(tuple|list|Tensor): Shape of output Tensor, the data type of ``shape`` is int32 or int64.
        dtype (np.dtype|str): Data type of output Tensor, it supports
            bool, float16, float32, float64, int32 and int64.
        force_cpu (bool, optional): Whether force to store the output Tensor in CPU memory.
            If :attr:`force_cpu` is False, the output Tensor will be stored in running device memory.
            Default: False.
        name(str, optional): The default value is None.  Normally there is no need for user to set this
            property.  For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]]
          # shape is a Tensor
          shape = fluid.layers.fill_constant(shape=[2], dtype='int32', value=2)
          data1 = fluid.layers.zeros(shape=shape, dtype='int32') #[[0, 0], [0, 0]]
    """
    return fill_constant(value=0.0, **locals())