varbase_patch_methods.py 38.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import inspect
16
import numpy as np
17 18
import warnings
import weakref
19
import sys
20 21

import paddle
22
from .. import framework
姜永久 已提交
23
from ..framework import convert_np_dtype_to_dtype_
24
from .. import core
25
from .. import unique_name
26 27 28 29 30 31 32 33 34
from ..framework import (
    Variable,
    Parameter,
    ParamBase,
    _getitem_impl_,
    _setitem_impl_,
    EagerParamBase,
    in_dygraph_mode,
)
35
from .base import switch_to_static_graph
36
from .math_op_patch import monkey_patch_math_varbase
37
from .parallel import scale_loss
L
Leo Chen 已提交
38
from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
39
import paddle.utils.deprecated as deprecated
C
chenjian 已提交
40
import paddle.profiler as profiler
41
from paddle.profiler.utils import in_profiler_mode
42
from paddle import _C_ops, _legacy_C_ops
43
from paddle.device import get_all_custom_device_type
44
from paddle.fluid.framework import _global_flags
45

46 47
# Module-level loss scaler consulted by Tensor.backward(): when it is set to a
# truthy object, backward() replaces the loss with ``_grad_scalar.scale(loss)``
# before running the backward pass (implicit loss scaling for AMP with Fleet
# DistributedStrategy, per the comment in backward). None disables scaling.
_grad_scalar = None

48

49
class TensorHookRemoveHelper:
    """
    A helper class for removing a gradient hook registered on a Tensor.

    NOTE(wuweilong): the operation weakref.ref(tensor) will cause some unexpected errors in eager mode,
    so in eager mode a strong reference to the Tensor is kept instead of a weak one.
    """

    def __init__(self, tensor, hook_id):
        if framework.global_var._in_eager_mode_:
            self._tensor = tensor
        else:
            self._tensor = weakref.ref(tensor)
        self._hook_id = hook_id

    def remove(self):
        """
        Remove the hook from the referenced Tensor.

        Returns:
            bool: True if the hook was removed; False if the Tensor no longer
            exists or the hook was not found (a RuntimeWarning is emitted in the
            latter case).
        """
        if framework.global_var._in_eager_mode_:
            owner = self._tensor
        else:
            # Dereference the weakref; yields None once the Tensor is gone.
            owner = self._tensor()

        if owner is None:
            return False

        if owner._remove_grad_hook(self._hook_id) is True:
            return True

        warnings.warn(
            "The backward hook (ID: %d) of Tensor `%s` you want to remove does not exist or has been removed."
            % (self._hook_id, owner.name),
            RuntimeWarning,
        )
        return False


88 89 90
# Module-level flag; it is not read anywhere in the visible part of this file.
# NOTE(review): presumably guards against patching ``__repr__`` more than once — confirm.
_already_patch_repr = False


91
def monkey_patch_varbase():
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
    @switch_to_static_graph
    def _to_static_var(self, to_parameter=False, **kwargs):
        """
        **Notes**:
            **This API is ONLY available in Dygraph mode**

        Transform a VarBase into static Variable with same attributes. It's a low level interface used
        in dy2static and shall not be called directly.

        Args:
            to_parameter (bool): If set True, the VarBase will be converted into framework.Parameter.
                                 Otherwise it will be converted into framework.Variable (unless the
                                 VarBase is already a ParamBase/EagerParamBase, which always becomes a
                                 Parameter). Default False.
            **kwargs: extra constructor attributes. They override the attributes copied from
                      this VarBase, including ``block``.

        Returns:
            Parameter|Variable: the static counterpart of this VarBase. Parameters are always
            placed into the global block of the target program.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
                import numpy as np

                data = np.ones([3, 1024], dtype='float32')
                with fluid.dygraph.guard():
                    var_base = to_variable(data)
                    static_var = var_base._to_static_var()

        """

        # Note: getattr(self, attr, None) will call x.grad=x.gradient(), but gradient() is only available in dygraph.
        # It will fail. So, for properties that differ between dynamic and static graph, we must not call getattr on them.
        attr_not_need_keys = ['grad', 'T', 'place', '_place_str']
        param_keys = ['stop_gradient', 'trainable']
        if isinstance(self, (ParamBase, EagerParamBase)):
            attr_kwargs = self.__dict__.copy()
            for key in param_keys:
                attr_kwargs[key] = getattr(self, key)
        else:
            attr_kwargs = {}
            for name in dir(self):
                # Apply the cheap name-based filters first so that private
                # attributes are never evaluated (their getters may fail outside
                # dygraph mode), and read each public attribute only once.
                if name in attr_not_need_keys or name.startswith('_'):
                    continue
                value = getattr(self, name)
                if not inspect.ismethod(value):
                    attr_kwargs[name] = value

        attr_keys = ['block', 'shape', 'dtype', 'type', 'name', 'persistable']
        for attr in attr_keys:
            attr_kwargs[attr] = getattr(self, attr, None)

        # Caller-supplied kwargs (including a specific ``block``) take precedence.
        attr_kwargs.update(kwargs)

        if to_parameter or isinstance(self, (ParamBase, EagerParamBase)):
            del attr_kwargs['persistable']
            # NOTE(Aurelius84): All parameters should be placed into global block.
            attr_kwargs['block'] = attr_kwargs['block'].program.global_block()
            return Parameter(**attr_kwargs)
        return Variable(**attr_kwargs)

157 158 159 160 161
    # TODO(jiabin): move this to cplusplus end if we find some performance issue on it
    @framework.dygraph_only
    def set_value(self, value):
        """
        **Notes**:
            **This API is ONLY available in Dygraph mode**

        Overwrite the value held by this Variable in place.

        Args:
            value (Variable|np.ndarray|dict|str): the new value. A dict is stored with
                ``set_vocab`` and a str with ``set_string_list`` (both must match this
                Variable's length); any other value must match this Variable's shape and dtype.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
                from paddle.nn import Linear
                import numpy as np

                data = np.ones([3, 1024], dtype='float32')
                with fluid.dygraph.guard():
                    linear = Linear(1024, 4)
                    t = to_variable(data)
                    linear(t)  # call with default weight
                    custom_weight = np.random.randn(1024, 4).astype("float32")
                    linear.weight.set_value(custom_weight)  # change existing weight
                    out = linear(t)  # call with different weight

        """
        base_tensor = (
            core.eager.Tensor
            if framework.global_var._in_eager_mode_
            else core.VarBase
        )
        assert isinstance(
            value, (np.ndarray, base_tensor, dict, str)
        ), "Variable set_value function, arguments type only support Variable, numpy, VarBase, dict, string."

        if isinstance(value, (dict, str)):
            assert len(self) == len(
                value
            ), "Variable length not match, Variable [ {} ] need tensor with length {} but load set tensor with length {}".format(
                self.name, len(self), len(value)
            )
            target = self.value()
            if isinstance(value, dict):
                target.set_vocab(value)
            else:
                target.set_string_list(value)
            return

        assert self.shape == list(
            value.shape
        ), "Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format(
            self.name, self.shape, value.shape
        )

        incoming_dtype = (
            value.dtype
            if isinstance(value, base_tensor)
            else convert_np_dtype_to_dtype_(value.dtype)
        )
        assert (
            self.dtype == incoming_dtype
        ), "Variable dtype not match, Variable [ {} ] need tensor with dtype {}  but load tensor with dtype {}".format(
            self.name, self.dtype, incoming_dtype
        )

        # NOTE(wuweilong): self could be VarBase or Tensor, the subsequent behavior is defined in different files:
        # if self is VarBase, value() returns the Variable bound in imperative.cc and get_tensor() is bound in pybind.cc;
        # if self is Tensor, value() returns self (defined in this file) and get_tensor() is defined in eager_method.cc.
        # This interface behavior will be unified in the future.
        self.value().get_tensor().set(
            value, framework._current_expected_place()
        )
230 231

    @framework.dygraph_only
    def backward(self, grad_tensor=None, retain_graph=False):
        """
        Run backward of current Graph which starts from current Tensor.

        The new gradient will accumulate on the previous gradient.

        You can clear gradient by ``Tensor.clear_grad()`` .

        Args:
            grad_tensor(Tensor, optional): initial gradient values of the current Tensor. If `grad_tensor` is None,
                the initial gradient values of the current Tensor would be Tensor filled with 1.0;
                if `grad_tensor` is not None, it must have the same shape as the current Tensor.
                The default value is None.

            retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
                like to add more ops to the built graph after calling this method( :code:`backward` ), set the parameter
                :code:`retain_graph` to True, then the grads will be retained. Thus, setting it to False is much more memory-efficient.
                Defaults to False.

        Returns:
            NoneType: None

        Raises:
            ValueError: if called outside DyGraph mode.

        Examples:
            .. code-block:: python

                import paddle
                x = paddle.to_tensor(5., stop_gradient=False)
                for i in range(5):
                    y = paddle.pow(x, 4.0)
                    y.backward()
                    print("{}: {}".format(i, x.grad))
                # 0: [500.]
                # 1: [1000.]
                # 2: [1500.]
                # 3: [2000.]
                # 4: [2500.]

                x.clear_grad()
                print("{}".format(x.grad))
                # 0.

                grad_tensor=paddle.to_tensor(2.)
                for i in range(5):
                    y = paddle.pow(x, 4.0)
                    y.backward(grad_tensor)
                    print("{}: {}".format(i, x.grad))
                # 0: [1000.]
                # 1: [2000.]
                # 2: [3000.]
                # 3: [4000.]
                # 4: [5000.]

        """
        if not framework._non_static_mode():
            raise ValueError(
                "Variable.backward() is only available in DyGraph mode"
            )

        # Remember the event we actually began, so that it is ended exactly
        # once even if validation or the backward pass raises, and even if the
        # profiler mode changes while backward runs.
        record_event = None
        if in_profiler_mode():
            record_event = profiler.RecordEvent(
                "Gradient Backward", profiler.TracerEventType.Backward
            )
            record_event.begin()

        try:
            if grad_tensor is not None:
                if framework.global_var._in_eager_mode_:
                    assert isinstance(
                        grad_tensor, core.eager.Tensor
                    ), "The type of grad_tensor must be paddle.Tensor"
                else:
                    assert isinstance(
                        grad_tensor, paddle.Tensor
                    ), "The type of grad_tensor must be paddle.Tensor"
                assert (
                    grad_tensor.shape == self.shape
                ), "Tensor shape not match, Tensor of grad_tensor [ {} ] with shape {} mismatch Tensor [ {} ] with shape {}".format(
                    grad_tensor.name, grad_tensor.shape, self.name, self.shape
                )

            if framework.global_var._in_eager_mode_:
                # The eager API takes a (possibly empty) list of grad tensors.
                grad_tensor = [] if grad_tensor is None else [grad_tensor]

            loss = self
            if _grad_scalar:
                # When using amp with Fleet DistributedStrategy, we do loss scaling implicitly.
                loss = _grad_scalar.scale(loss)
            if (
                paddle.is_compiled_with_xpu()
                or paddle.is_compiled_with_npu()
                or paddle.is_compiled_with_mlu()
            ):
                # TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
                loss = scale_loss(loss)

            if framework.global_var._in_eager_mode_:
                core.eager.run_backward([loss], grad_tensor, retain_graph)
            else:
                core.dygraph_run_backward(
                    [loss],
                    [grad_tensor],
                    retain_graph,
                    framework._dygraph_tracer(),
                )
        finally:
            if record_event is not None:
                record_event.end()
347 348

    @framework.dygraph_only
    @deprecated(
        since="2.1.0",
        level=1,
        reason="Please use tensor.grad, which returns the tensor value of the gradient.",
    )
    def gradient(self):
        """
        .. warning::
          This API will be deprecated in the future, it is recommended to use
          :code:`x.grad` which returns the tensor value of the gradient.

        Get the Gradient of Current Tensor.

        Returns:
            ndarray: Numpy value of the gradient of current Tensor, or None if there is
            no gradient. For a SelectedRows gradient a tuple ``(values, rows)`` of two
            ndarrays is returned instead.

        Examples:
            .. code-block:: python

                import paddle

                x = paddle.to_tensor(5., stop_gradient=False)
                y = paddle.pow(x, 4.0)
                y.backward()
                print("grad of x: {}".format(x.gradient()))
                # [500.]

        """
        if framework.global_var._in_eager_mode_:
            # Read the gradient once: ``grad`` may be a property that emits a
            # warning and re-fetches the gradient on every access.
            grad = self.grad
            if grad is None:
                return None
            if grad.is_selected_rows():
                return (np.array(grad.numpy()), np.array(grad.rows()))
            return grad.numpy()

        grad_ivar = self._grad_ivar()
        if grad_ivar is None:
            return None

        new_ivar = grad_ivar
        # TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op
        if (
            _global_flags()['FLAGS_npu_storage_format']
            and 'npu' in get_all_custom_device_type()
        ):
            new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1)
        new_ivar = new_ivar._copy_to(core.CPUPlace(), True)
        # The type check uses the original gradient, before any format conversion/copy.
        if grad_ivar.type == core.VarDesc.VarType.SELECTED_ROWS:
            return (
                np.array(new_ivar.value().get_selected_rows().get_tensor()),
                np.array(new_ivar.value().get_selected_rows().rows()),
            )
        return np.array(new_ivar.value().get_tensor())
402

403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463
    @framework.dygraph_only
    def register_hook(self, hook):
        """
        Register a backward hook on the current Tensor.

        Every time the gradient of this Tensor is computed, the hook is invoked
        with that gradient. The hook must not modify its input in place; it may
        instead return a new Tensor, which then replaces the gradient.

        Expected hook signature:

            hook(grad) -> Tensor or None

        Args:
            hook(function): the backward hook to attach to Tensor.grad

        Returns:
            TensorHookRemoveHelper: an object whose `remove()` method detaches the hook again.

        Raises:
            RuntimeError: if this Tensor has ``stop_gradient`` set to True.

        Examples:
            .. code-block:: python

                import paddle

                # hook function return None
                def print_hook_fn(grad):
                    print(grad)

                # hook function return Tensor
                def double_hook_fn(grad):
                    grad = grad * 2
                    return grad

                x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
                y = paddle.to_tensor([4., 5., 6., 7.], stop_gradient=False)
                z = paddle.to_tensor([1., 2., 3., 4.])

                # one Tensor can register multiple hooks
                h = x.register_hook(print_hook_fn)
                x.register_hook(double_hook_fn)

                w = x + y
                # register hook by lambda function
                w.register_hook(lambda grad: grad * 2)

                o = z.matmul(w)
                o.backward()
                # print_hook_fn print content in backward
                # Tensor(shape=[4], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
                #        [2., 4., 6., 8.])

                print("w.grad:", w.grad) # w.grad: [1. 2. 3. 4.]
                print("x.grad:", x.grad) # x.grad: [ 4.  8. 12. 16.]
                print("y.grad:", y.grad) # y.grad: [2. 4. 6. 8.]

                # remove hook
                h.remove()
        """
        if self.stop_gradient is True:
            raise RuntimeError(
                "Cannot register hook on a tensor that stop gradient."
            )

        registered_id = self._register_grad_hook(hook)
        return TensorHookRemoveHelper(self, registered_id)

471 472 473 474 475 476 477 478 479
    @framework.dygraph_only
    def _to(self, device=None, dtype=None, blocking=None):
        """
        Move and/or cast this Tensor in place, then return it.

        Args:
            device (str|CPUPlace|CUDAPlace|CUDAPinnedPlace|XPUPlace|CustomPlace, optional):
                target place; a str is converted with ``paddle.device._convert_to_place``.
                None keeps the current place.
            dtype (str|dtype, optional): target dtype; a str is converted with
                ``convert_np_dtype_to_dtype_``. None keeps the current dtype.
            blocking (bool, optional): forwarded to ``_copy_to``; None means True.

        Returns:
            Tensor: ``self``. Its underlying tensor is made to share data with the
            converted result. When all three arguments are None, ``self`` is
            returned untouched.

        Raises:
            ValueError: if ``device`` has an unsupported type.

        UserWarnings raised while converting are suppressed.
        """
        if device is None and dtype is None and blocking is None:
            return self

        if device is not None:
            if isinstance(device, str):
                device = paddle.device._convert_to_place(device)
            elif not isinstance(
                device,
                (
                    core.CPUPlace,
                    core.CUDAPlace,
                    core.CUDAPinnedPlace,
                    core.XPUPlace,
                    core.CustomPlace,
                ),
            ):
                raise ValueError(
                    "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace(), paddle.XPUPlace() or paddle.CustomPlace(), but the type of device is "
                    + type(device).__name__
                )

        if blocking is not None:
            assert isinstance(
                blocking, bool
            ), "blocking value error, must be the True, False or None"
        else:
            blocking = True

        def _convert_inplace(src, place, target_dtype, sync):
            # Fill in defaults from the source tensor itself.
            place = src.place if place is None else place
            target_dtype = src.dtype if target_dtype is None else target_dtype
            if type(target_dtype) is str:
                target_dtype = framework.convert_np_dtype_to_dtype_(
                    target_dtype
                )

            # Stage 1: on GPU, check there is room for the converted tensor;
            # otherwise stage through CPU and free the GPU copy first.
            staged = src
            if src.place.is_gpu_place():
                element_bytes = core.size_of_dtype(target_dtype)
                # Note(weilong wu): Paddle GPU minimum memory allocation unit is 256 bytes;
                # the 1.2 factor leaves headroom against OOM when memory is just enough.
                # NOTE(review): this uses true division, so no rounding to 256-byte
                # units actually happens (estimate == (bytes + 256) * 1.2) — confirm intent.
                required_bytes = (
                    ((src._numel() * element_bytes) / 256 + 1) * 256 * 1.2
                )
                if core.gpu_memory_available() < required_bytes:
                    staged = src._copy_to(paddle.CPUPlace(), sync)
                    src._clear()

            # Stage 2: cast on whichever place the staged tensor lives.
            casted = staged
            if target_dtype is not None and target_dtype != staged.dtype:
                with paddle.fluid.framework._dygraph_place_guard(
                    place=staged.place
                ):
                    casted = staged.cast(dtype=target_dtype)

            # Stage 3: copy to the requested place if it differs.
            result = casted
            if place is not None and not casted.place._equals(place):
                result = casted._copy_to(place, sync)

            # Stage 4: make the original tensor share the converted storage.
            src.value().get_tensor()._share_data_with(
                result.value().get_tensor()
            )
            return src

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            return _convert_inplace(self, device, dtype, blocking)

559 560 561
    @property
    def grad(self):
        """
        .. warning::
          This API returns the gradient as a Tensor. To obtain it as a numpy
          array, call :code:`x.grad.numpy()`.

        Get the Gradient of Current Tensor. Every access emits a warning about the
        return type change between versions 2.0 and 2.1.0.

        Returns:
            Tensor: the gradient of current Tensor

        Examples:
            .. code-block:: python

                import paddle

                x = paddle.to_tensor(5., stop_gradient=False)
                y = paddle.pow(x, 4.0)
                y.backward()
                print("grad of x: {}".format(x.grad))
                # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False, [500.])

        """
        msg = (
            'tensor.grad will return the tensor value of the gradient.'
            ' This is an incompatible upgrade for tensor.grad API. '
            " It's return type changes from numpy.ndarray in version 2.0 to paddle.Tensor in version 2.1.0. "
            ' If you want to get the numpy value of the gradient, you can use :code:`x.grad.numpy()`'
        )
        if sys.platform.lower() == 'win32':
            # Plain text: ANSI color escapes do not render in cmd / powershell.
            warning_msg = "\nWarning:\n%s " % (msg)
        else:
            warning_msg = "\033[93m\nWarning:\n%s \033[0m" % (msg)
        warnings.warn(warning_msg)
        return self._grad_ivar()
595

596 597 598 599 600 601
    def clear_grad(self):
        """
        Alias of ``clear_gradient()``; always returns None.
        """
        clear = self.clear_gradient
        clear()

602 603
    def item(self, *args):
        """
        Convert the element at a given position of the Tensor into a Python scalar.
        When no position is given, the Tensor must contain exactly one element.

        Args:
            *args(int): The input coordinates. A single int indexes the flattened Tensor;
                several ints index each dimension. Default: no coordinates, which is only
                valid for a single-element Tensor.

        Returns:
            Python scalar: a Python scalar whose type corresponds to the dtype of the Tensor.

        Raises:
            ValueError: If the Tensor has more than one element, there must be coordinates.

        Examples:
            .. code-block:: python

                import paddle

                x = paddle.to_tensor(1)
                print(x.item())             #1
                print(type(x.item()))       #<class 'int'>

                x = paddle.to_tensor(1.0)
                print(x.item())             #1.0
                print(type(x.item()))       #<class 'float'>

                x = paddle.to_tensor(True)
                print(x.item())             #True
                print(type(x.item()))       #<class 'bool'>

                x = paddle.to_tensor(1+1j)
                print(x.item())             #(1+1j)
                print(type(x.item()))       #<class 'complex'>

                x = paddle.to_tensor([[1.1, 2.2, 3.3]])
                print(x.item(2))            #3.3
                print(x.item(0, 2))         #3.3

        """
        element = self._getitem_from_offset(*args)
        return element.item()

644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664
    @property
    def inplace_version(self):
        """
        Read-only counter of in-place modifications of the current Tensor.
        It is incremented each time the Tensor is modified through an inplace operation.

        **Notes: This is a read-only property**

        Examples:
          .. code-block:: python

            import paddle
            var = paddle.ones(shape=[4, 2, 3], dtype="float32")
            print(var.inplace_version)  # 0

            var[1] = 2.2
            print(var.inplace_version)  # 1

        """
        version = self._inplace_version()
        return version

665 666
    def __str__(self):
        """
        Render this Tensor as a human-readable string.

        Returns(str): A readable string.

        Examples:
            .. code-block:: python

                import paddle
                x = paddle.rand([2, 5])
                print(x)

                # Tensor(shape=[2, 5], dtype=float32, place=CPUPlace,
                #        [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436],
                #         [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]])
        """
        # Eager Tensors and legacy VarBase use different formatters.
        if framework.global_var._in_eager_mode_:
            from paddle.tensor.to_string import tensor_to_string as _render
        else:
            from paddle.tensor.to_string import to_string as _render
        return _render(self)
690

691 692 693 694 695 696 697 698 699 700 701
    def __deepcopy__(self, memo):
        """
        Deep copy Tensor; the data is always copied into a new Tensor.

        Raises:
            RuntimeError: if this Tensor is not a leaf Tensor.

        Examples:
            .. code-block:: python

                import paddle
                import copy
                x = paddle.to_tensor(2.)
                y = copy.deepcopy(x)

                print(x)
                # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True,
                #        [2.])

                print(y)
                # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True,
                #        [2.])

        """
        if not self.is_leaf:
            raise RuntimeError(
                "Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy"
            )
        tensor_cls = (
            core.eager.Tensor
            if framework.global_var._in_eager_mode_
            else core.VarBase
        )
        duplicate = tensor_cls()
        duplicate.name = self.name + unique_name.generate("_deepcopy")
        # Register in memo before copying, as the deepcopy protocol requires.
        memo[id(self)] = duplicate
        duplicate.copy_(self, True)
        return duplicate

725 726 727
    @property
    def block(self):
        """The global block of the default main program (the same block for every Tensor)."""
        main_program = framework.default_main_program()
        return main_program.global_block()
728

729 730
    def __nonzero__(self):
        """
        Truth value of a single-element Tensor, e.g. when used in if/while.

        Returns:
            bool: whether the (single) element is greater than zero.

        Raises:
            AssertionError: if the Tensor has more than one element or is not initialized.
        """
        element_count = np.prod(self.shape)
        assert (
            element_count == 1
        ), "When Variable is used as the condition of if/while , Variable can only contain one element."
        # NOTE(review): ``> 0`` makes negative values falsy, unlike Python's
        # bool() of a nonzero number — confirm this is intended.
        if not framework.global_var._in_eager_mode_:
            tensor = self.value().get_tensor()
            assert tensor._is_initialized(), "tensor not initialized"
            return bool(np.all(tensor.__array__() > 0))
        assert self._is_initialized(), "tensor not initialized"
        return bool(np.all(self.numpy() > 0))
741 742 743 744

    def __bool__(self):
        """Truth value of the Tensor; delegates to ``__nonzero__``."""
        truth = self.__nonzero__()
        return truth

745
    def __array__(self, dtype=None):
        """
        Return a numpy array holding the value of the current Tensor.

        Args:
            dtype (str|np.dtype|type, optional): if given, the array is converted with
                ``ndarray.astype(dtype)``. Default None keeps the Tensor's dtype.

        Returns:
            ndarray: The numpy value of current Tensor.

        Returns type:
            ndarray: dtype is same as current Tensor unless ``dtype`` is given

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np
                x = paddle.randn([2, 2])
                x_array = np.array(x)

                print(type(x_array))      #<class 'numpy.ndarray'>
                print(x_array.shape)      #(2, 2)
        """
        array = self.numpy()
        # Compare against None explicitly: a ``np.dtype`` instance is falsy
        # (its ``len()`` is the number of fields, 0 for plain dtypes), so
        # ``if dtype:`` would silently ignore a requested dtype.
        if dtype is not None:
            array = array.astype(dtype)
        return array
770

W
WeiXin 已提交
771
    def contain_tensor(item):
        """
        Return True if an index expression contains a tensor.

        ``item`` is a single index or a tuple/list of indices. It counts as
        containing a tensor when a slice has a Variable start/stop/step, or when
        an index is a Variable or np.ndarray (subject to the dtype condition below).
        """
        indices = item if isinstance(item, (tuple, list)) else [item]

        for index in indices:
            if isinstance(index, slice):
                bounds = (index.start, index.stop, index.step)
                if any(isinstance(bound, Variable) for bound in bounds):
                    return True
            elif (
                isinstance(index, (Variable, np.ndarray))
                and Variable.dtype != paddle.bool
            ):
                # NOTE(review): ``Variable.dtype`` is looked up on the class, not on
                # ``index``. On the class it is presumably the ``dtype`` property object,
                # so this comparison likely always holds and bool masks are treated as
                # tensors too. Confirm that the C++ ``_getitem_index_not_tensor`` path
                # supports bool masks before changing this to ``index.dtype``.
                return True
        return False

791
    def __getitem__(self, item):
        """
        Index this tensor with ``item``.

        The Python ``_getitem_impl_`` is used in two cases. The first is when
        ``item`` contains a tensor (see ``contain_tensor``). The second is
        when ``item`` is a tuple/list whose leaves are all ``int``, searched
        through nested tuples/lists. Every other index goes to the C++
        ``_getitem_index_not_tensor`` fast path.
        """

        def _nested_all_of_type(index, contain_type):
            # True when ``index`` is a tuple/list and every leaf reached by
            # recursing through nested tuples/lists has exactly type
            # ``contain_type``. An empty container counts as a match.
            def _leaf_ok(node):
                if isinstance(node, (tuple, list)):
                    return all(_leaf_ok(child) for child in node)
                return type(node) == contain_type

            return isinstance(index, (tuple, list)) and all(
                _leaf_ok(elem) for elem in index
            )

        if not (contain_tensor(item) or _nested_all_of_type(item, int)):
            # Fast path: the C++ function parses the index directly.
            return self._getitem_index_not_tensor(item)

        # An index holding tensors can't be parsed by the C++ function, so
        # fall back to the Python implementation.
        return _getitem_impl_(self, item)

W
WeiXin 已提交
819
    def __setitem__(self, item, value):
        """
        Write ``value`` into the positions of this tensor selected by ``item``.

        The Python ``_setitem_impl_`` (shared with the static graph) is used
        when ``item`` contains a list or tensor and is not a "combined" index.
        Every other index goes to the C++ fast path for the current mode.
        """

        def _has_tensor_or_list(index):
            # Only a tuple is unpacked into sub-indices; a bare list is
            # itself treated as a single sub-index.
            parts = index if isinstance(index, tuple) else [index]
            return any(isinstance(part, (list, Variable)) for part in parts)

        def _is_mixed_index(index):
            # A tuple/list index is "combined" when its elements are not all
            # of one Python type, or when its tensor elements disagree on
            # dtype.
            if not isinstance(index, (tuple, list)):
                return False
            seen_type = None
            seen_dtype = None
            for part in index:
                part_type = type(part)
                if seen_type is None:
                    seen_type = part_type
                elif part_type != seen_type:
                    return True

                if isinstance(part, Variable):
                    if seen_dtype is None:
                        seen_dtype = part.dtype
                    elif seen_dtype != part.dtype:
                        return True
            return False

        if _has_tensor_or_list(item) and not _is_mixed_index(item):
            # Reuse the static-graph code path for tensor/list indices.
            return _setitem_impl_(self, item, value)

        if framework.global_var._in_eager_mode_:
            return self.__setitem_eager_tensor__(item, value)
        # Legacy VarBase: C++ __setitem_varbase__ for speed.
        return self.__setitem_varbase__(item, value)
W
WeiXin 已提交
864

865 866 867 868
    @framework.dygraph_only
    def _set_grad_ivar(self, value):
        """
        Assign ``value`` as the gradient of this parameter.

        After assigning ``self.grad``, ``_unset_fake_empty()`` is called.

        Raises:
            TypeError: if ``self`` is not an ``EagerParamBase``.
        """
        if not isinstance(self, EagerParamBase):
            raise TypeError(
                "_set_grad_ivar is only supported for Parameter Tensor"
            )
        self.grad = value
        self._unset_fake_empty()
874

875 876 877 878
    @framework.dygraph_only
    def value(self):
        """Return this tensor unchanged."""
        tensor = self
        return tensor

J
Jiabin Yang 已提交
879 880 881 882 883 884 885 886
    @framework.dygraph_only
    def _slice(self, begin_idx, end_idx):
        """
        Slice the underlying tensor with ``_slice(begin_idx, end_idx)`` and
        wrap the result in a new ``core.eager.Tensor``.
        """
        sliced = self.get_tensor()._slice(begin_idx, end_idx)
        return core.eager.Tensor(sliced)

    @framework.dygraph_only
    def _numel(self):
        """Return the element count reported by the underlying tensor."""
        dense = self.get_tensor()
        return dense._numel()

B
Baibaifan 已提交
887 888 889 890
    @framework.dygraph_only
    def _clear_data(self):
        """Call ``_clear()`` on the underlying tensor; returns None."""
        dense = self.get_tensor()
        dense._clear()

891
    @framework.dygraph_only
    def _use_gpudnn(self, use_gpudnn=True):
        """Forward ``use_gpudnn`` to ``_tensor_use_gpudnn`` and return its result."""
        result = self._tensor_use_gpudnn(use_gpudnn)
        return result
894

895 896
    @framework.dygraph_only
    def _uva(self, device_id=0):
        '''
        Put this tensor's data in UVA (unified virtual addressing) memory.

        The work is delegated to ``self._tensor_uva(device_id)``. This method
        returns None; as the example shows, ``x`` is used directly after the
        call.

        Args:
            device_id(int, optional): The destination GPU device id. Default: 0.

        Examples:
            .. code-block:: python

              # required: gpu
              import paddle
              x = paddle.to_tensor([1, 2, 3], place=paddle.CPUPlace())
              x._uva()
              print(x)
        '''
        self._tensor_uva(device_id)

J
Jiabin Yang 已提交
914 915 916 917 918 919 920 921 922 923 924
    @framework.dygraph_only
    def cpu(self):
        """
        Return this tensor on CPU.

        If it is already on a CPU place, ``self`` is returned. Otherwise a copy
        is made that keeps ``stop_gradient`` and ``persistable``.
        """
        if self.place.is_cpu_place():
            return self
        copied = self._copy_to(core.CPUPlace(), True)
        copied.stop_gradient = self.stop_gradient
        copied.persistable = self.persistable
        return copied

    @framework.dygraph_only
    def cuda(self, device_id=None, blocking=True):
        """
        Return this tensor on a CUDA device, copying it only when needed.

        Args:
            device_id(int|None, optional): Target GPU id. When None, the
                current expected place is used if it is a ``CUDAPlace``;
                otherwise GPU 0 is used. Default: None.
            blocking(bool, optional): Passed to ``_copy_to`` as its blocking
                flag. Default: True.

        Returns:
            Tensor: ``self`` if it is already on the target place. Otherwise a
            copy that keeps ``stop_gradient`` and ``persistable``.

        Raises:
            ValueError: if ``device_id`` is neither an int nor None.
        """
        if device_id is None:
            res_place = framework._current_expected_place()
            if not isinstance(res_place, core.CUDAPlace):
                res_place = core.CUDAPlace(0)
        elif isinstance(device_id, int):
            res_place = core.CUDAPlace(device_id)
        else:
            raise ValueError("device_id must be int|None")

        if self.place._equals(res_place):
            return self

        # Forward the caller's ``blocking`` flag instead of a hard-coded True,
        # so that ``cuda(blocking=False)`` actually requests a non-blocking copy.
        res = self._copy_to(res_place, blocking)
        res.stop_gradient = self.stop_gradient
        res.persistable = self.persistable
        return res

W
wanghuancoder 已提交
943 944 945 946 947 948 949 950 951 952
    @framework.dygraph_only
    def pin_memory(self):
        """
        Return this tensor in CUDA pinned (page-locked) host memory.

        If it is already on a CUDA-pinned place, ``self`` is returned.
        Otherwise a copy is made that keeps ``stop_gradient`` and
        ``persistable``.
        """
        if self.place.is_cuda_pinned_place():
            return self
        pinned = self._copy_to(core.CUDAPinnedPlace(), True)
        pinned.stop_gradient = self.stop_gradient
        pinned.persistable = self.persistable
        return pinned

953 954
    @framework.dygraph_only
    def values(self):
        """
        **Notes**:
            **This API is ONLY available in Dygraph mode**

        Get the stored values of the current SparseTensor (COO or CSR).

        Returns:
            Tensor: A DenseTensor holding the stored values.

        Examples:
            .. code-block:: python

                import paddle
                indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
                values = [1, 2, 3, 4, 5]
                dense_shape = [3, 4]
                sparse_x = paddle.sparse.sparse_coo_tensor(paddle.to_tensor(indices, dtype='int32'), paddle.to_tensor(values, dtype='float32'), shape=dense_shape)
                print(sparse_x.values())
                #[1., 2., 3., 4., 5.]
        """
        stored = _C_ops.sparse_values(self)
        return stored
975 976 977

    @framework.dygraph_only
    def to_dense(self):
        """
        **Notes**:
            **This API is ONLY available in Dygraph mode**

        Convert the current SparseTensor (COO or CSR) into a DenseTensor.

        Returns:
            Tensor: A DenseTensor

        Examples:
            .. code-block:: python

                import paddle
                indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
                values = [1, 2, 3, 4, 5]
                dense_shape = [3, 4]
                sparse_x = paddle.sparse.sparse_coo_tensor(paddle.to_tensor(indices, dtype='int64'), paddle.to_tensor(values, dtype='float32'), shape=dense_shape)
                dense_x = sparse_x.to_dense()
                #[[0., 1., 0., 2.],
                # [0., 0., 3., 0.],
                # [4., 5., 0., 0.]]
        """
        dense = _C_ops.sparse_to_dense(self)
        return dense
1001 1002 1003

    @framework.dygraph_only
    def to_sparse_coo(self, sparse_dim):
        """
        **Notes**:
            **This API is ONLY available in Dygraph mode**

        Convert the current DenseTensor into a SparseTensor in COO format.

        Args:
            sparse_dim(int): The number of sparse dimensions of the result,
                passed through to ``sparse_to_sparse_coo``.

        Returns:
            Tensor: A SparseCooTensor

        Examples:
            .. code-block:: python

                import paddle
                dense_x = [[0, 1, 0, 2], [0, 0, 3, 4]]
                dense_x = paddle.to_tensor(dense_x, dtype='float32')
                sparse_x = dense_x.to_sparse_coo(sparse_dim=2)
                #indices=[[0, 0, 1, 1],
                #         [1, 3, 2, 3]],
                #values=[1., 2., 3., 4.]
        """
        coo = _C_ops.sparse_to_sparse_coo(self, sparse_dim)
        return coo
1025

1026 1027 1028
    def __hash__(self):
        return hash(id(self))

1029
    if framework.global_var._in_eager_mode_ and not hasattr(core, "eager"):
1030 1031
        return

1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055
    for method_name, method in (
        ("__bool__", __bool__),
        ("__nonzero__", __nonzero__),
        ("_to_static_var", _to_static_var),
        ("set_value", set_value),
        ("block", block),
        ("backward", backward),
        ("clear_grad", clear_grad),
        ("inplace_version", inplace_version),
        ("gradient", gradient),
        ("register_hook", register_hook),
        ("__str__", __str__),
        ("__repr__", __str__),
        ("__deepcopy__", __deepcopy__),
        ("__module__", "paddle"),
        ("__array__", __array__),
        ("__getitem__", __getitem__),
        ("item", item),
        ("__setitem__", __setitem__),
        ("_to", _to),
        ("values", values),
        ("to_dense", to_dense),
        ("to_sparse_coo", to_sparse_coo),
    ):
1056
        if framework.global_var._in_eager_mode_:
1057
            setattr(core.eager.Tensor, method_name, method)
L
Leo Chen 已提交
1058
        else:
1059 1060
            setattr(core.VarBase, method_name, method)

1061
    if framework.global_var._in_eager_mode_:
1062 1063
        setattr(core.eager.Tensor, "_set_grad_ivar", _set_grad_ivar)
        setattr(core.eager.Tensor, "value", value)
J
Jiabin Yang 已提交
1064 1065
        setattr(core.eager.Tensor, "cpu", cpu)
        setattr(core.eager.Tensor, "cuda", cuda)
W
wanghuancoder 已提交
1066
        setattr(core.eager.Tensor, "pin_memory", pin_memory)
J
Jiabin Yang 已提交
1067 1068
        setattr(core.eager.Tensor, "_slice", _slice)
        setattr(core.eager.Tensor, "_numel", _numel)
1069
        setattr(core.eager.Tensor, "_uva", _uva)
B
Baibaifan 已提交
1070
        setattr(core.eager.Tensor, "_clear_data", _clear_data)
1071
        setattr(core.eager.Tensor, "__hash__", __hash__)
1072
        setattr(core.eager.Tensor, "_use_gpudnn", _use_gpudnn)
1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085
    else:
        setattr(core.VarBase, "__name__", "Tensor")
        setattr(core.VarBase, "grad", grad)

    global _already_patch_repr
    if not _already_patch_repr:
        # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
        # So, we need to overwrite it to a more readable one.
        # See details in https://github.com/pybind/pybind11/issues/2537.
        origin = getattr(core.VarDesc.VarType, "__repr__")

        def dtype_str(dtype):
            if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
1086 1087 1088
                numpy_dtype = _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
                if numpy_dtype == 'uint16':
                    numpy_dtype = 'bfloat16'
1089
                prefix = 'paddle.'
1090
                return prefix + numpy_dtype
1091 1092 1093
            else:
                # for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR
                return origin(dtype)
L
Leo Chen 已提交
1094

1095 1096
        setattr(core.VarDesc.VarType, "__repr__", dtype_str)
        _already_patch_repr = True
L
Leo Chen 已提交
1097

1098 1099
    # patch math methods for varbase
    monkey_patch_math_varbase()