base.py 24.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
S
songyouwei 已提交
15
import decorator
16
import contextlib
17 18
import functools
import inspect
19
import sys
20 21 22
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
H
hong 已提交
23
from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar
M
minqiyang 已提交
24
from .tracer import Tracer
Z
Zeng Jinle 已提交
25
import logging
26
from ..data_feeder import convert_dtype
L
Leo Chen 已提交
27
import warnings
28
from ..framework import _get_paddle_place
29
import paddle
30

31
# Public names exported by ``from ... import *`` on this module.
__all__ = [
    'no_grad',
    'no_grad_',
    'grad',
    'guard',
    'enable_dygraph',
    'disable_dygraph',
    'enabled',
    'to_variable',
]
35 36


37 38 39 40 41 42 43 44 45 46 47
def _switch_to_static_graph_(func):
    # Return a wrapper that runs ``func`` inside ``framework._dygraph_guard(None)``,
    # i.e. with no dygraph tracer installed for the duration of the call.
    def _static_graph_impl(*args, **kwargs):
        no_tracer_ctx = framework._dygraph_guard(None)
        with no_tracer_ctx:
            result = func(*args, **kwargs)
        return result

    return _static_graph_impl


# Decorator form of ``_switch_to_static_graph_``, built via ``wrap_decorator``.
switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)


48 49 50 51 52 53
@signature_safe_contextmanager
def program_desc_tracing_guard(enable):
    """
    Temporarily set the current dygraph tracer's
    ``_enable_program_desc_tracing`` flag to ``enable``.

    The previous flag value is restored when the context exits (also on
    exceptions). When there is no active tracer the context does nothing.
    """
    tracer = framework._dygraph_tracer()
    if not tracer:
        yield
        return

    saved_flag = tracer._enable_program_desc_tracing
    tracer._enable_program_desc_tracing = enable
    try:
        yield
    finally:
        tracer._enable_program_desc_tracing = saved_flag
59 60


61 62 63
# Context manager entered by ``enable_dygraph`` and exited (then reset to None)
# by ``disable_dygraph``; None while no such context is active.
_functional_dygraph_context_manager = None


64 65
@signature_safe_contextmanager
def param_guard(parameters):
    """
    Temporarily replace ``core.VarBase`` entries of ``parameters`` with
    static-graph variables while NOT in dygraph mode.

    Inside the ``with`` body every ``core.VarBase`` value of ``parameters`` is
    replaced in place:

    * ``framework.ParamBase`` values are converted with
      ``_to_static_var(to_parameter=True)``;
    * other ``VarBase`` values reuse the variable of the same name already in
      their block, or are converted with ``_to_static_var(to_parameter=False)``.

    The original entries are restored when the context exits, including when
    the body (or a conversion) raises. In dygraph mode, or when ``parameters``
    is empty, the mapping is left untouched.

    Args:
        parameters (dict): mapping of name -> value, mutated in place.
    """
    # Note: parameters is a reference of self._parameters or self._buffers
    if framework.in_dygraph_mode() or not parameters:
        yield
        return

    origin_parameters = parameters.copy()
    try:
        for name, var_base in origin_parameters.items():
            if not isinstance(var_base, core.VarBase):
                continue
            if isinstance(var_base, framework.ParamBase):
                # Convert ParamBase into Parameter with same attributes in dy2stat.
                new_var = var_base._to_static_var(to_parameter=True)
            elif var_base.name in var_base.block.vars:
                # Reuse the variable if it has been created before.
                new_var = var_base.block.vars[var_base.name]
            else:
                # Note(Aurelius84): Convert VarBase in self._buffers into Variabe with
                # same attributes and set persistable=True to allow saving this var.
                # Because users can create a VarBase in `__init__`  like a
                # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
                # and necessary for inferring. It will be pruned if it's not necessary for inferring.

                # But if its shape is empty while created from `create_variable()`, we consider this buffer
                # non-persistable. See case of `drop_state` in lstm api.
                is_persistable = len(var_base.shape) > 0
                new_var = var_base._to_static_var(
                    to_parameter=False, persistable=is_persistable)
            parameters[name] = new_var
        yield
    finally:
        # Restore the original entries even if the body raised, so the layer is
        # not left holding static-graph variables.
        parameters.update(origin_parameters)


97
def enabled():
    """
    This function checks whether the program runs in dynamic graph mode or not.
    You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api,
    or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable_dygraph`
    and :ref:`api_fluid_dygraph_disable_dygraph` api .

    **Note**:
        ``fluid.dygraph.enabled`` is the alias of ``fluid.in_dygraph_mode``, and
        ``fluid.in_dygraph_mode`` is recommended to use.

    Returns:
        bool: Whether the program is running in dynamic graph mode.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            fluid.enable_dygraph()  # Now we are in dygraph mode
            print(fluid.dygraph.enabled())  # True
            fluid.disable_dygraph()
            print(fluid.dygraph.enabled())  # False
    """
    return framework.in_dygraph_mode()
122 123


124 125
def enable_dygraph(place=None):
    """

    .. note::
        Dynamic graph mode is turned ON by default since paddle 2.0.0

    This API turns OFF static graph mode (i.e. turns ON dynamic graph mode).
    You can turn ON static graph mode by `enable_static <./disable_dygraph_en.html>`_ .

    If a context entered by a previous call is still active (it has not been
    undone by ``disable_dygraph``), this call does nothing.

    Parameters:
        place(paddle.CPUPlace|paddle.CUDAPlace|str, optional): Place to run dynamic graph. Default: None. Which means that the running place will be
            determined according to the way of paddle compilation. If ``place`` is string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the
            index of the GPUs.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            print(paddle.in_dynamic_mode())  # True, dynamic mode is turned ON by default since paddle 2.0.0

            paddle.enable_static()
            print(paddle.in_dynamic_mode())  # False, Now we are in static mode

            paddle.disable_static()
            print(paddle.in_dynamic_mode())  # True, Now we are in dynamic mode

    """
    global _functional_dygraph_context_manager
    if _functional_dygraph_context_manager is None:
        manager = guard(place=_get_paddle_place(place))
        manager.__enter__()
        # Publish the manager only after a successful __enter__, so a failure
        # does not leave a stale manager that would make later calls no-ops.
        _functional_dygraph_context_manager = manager

        # call disable_dygraph when Python exit
        CleanupFuncRegistrar.register(disable_dygraph)

162 163 164

def disable_dygraph():
    """

    .. note::
        Dynamic graph mode is turned ON by default since paddle 2.0.0

    This API turns ON static graph mode (i.e. turns OFF dynamic graph mode).
    You can turn OFF static graph mode by `disable_static <./enable_dygraph_en.html>`_ .

    It exits the context entered by ``enable_dygraph``; if no such context is
    active, it does nothing. The exception currently being handled (if any), as
    reported by ``sys.exc_info()``, is forwarded to that context's ``__exit__``.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            print(paddle.in_dynamic_mode())  # True, dynamic mode is turned ON by default since paddle 2.0.0

            paddle.enable_static()
            print(paddle.in_dynamic_mode())  # False, Now we are in static mode

            paddle.disable_static()
            print(paddle.in_dynamic_mode())  # True, Now we are in dynamic mode

    """
    global _functional_dygraph_context_manager
    manager = _functional_dygraph_context_manager
    if manager is not None:
        # Clear the global before exiting so that a failing __exit__ cannot leave
        # a stale manager that would make later enable_dygraph() calls no-ops.
        _functional_dygraph_context_manager = None
        manager.__exit__(*sys.exc_info())


193 194 195 196
@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    """
    Temporarily set the active dygraph tracer's ``_has_grad`` flag to
    ``is_train``, restoring the previous value on exit (also on exceptions).
    Does nothing when there is no active tracer.
    """
    tracer = framework._dygraph_tracer()
    if not tracer:
        yield
        return

    previous_has_grad = tracer._has_grad
    tracer._has_grad = is_train
    try:
        yield
    finally:
        tracer._has_grad = previous_has_grad


def no_grad(func=None):
    """
    :api_attr: imperative

    Create a context which disables dygraph gradient calculation.
    In this mode, the result of every computation will have `stop_gradient=True`.

    Also functions as a decorator. (Make sure to instantiate without parenthesis.)
    When decorating a generator function, gradient calculation is disabled
    while the generator body runs, not only while the generator is created.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        # use as context manager

        data = np.array([[2, 3], [4, 5]]).astype('float32')
        with fluid.dygraph.guard():
            l0 = fluid.Linear(2, 2)  # l0.weight.gradient() is None
            l1 = fluid.Linear(2, 2)
            with fluid.dygraph.no_grad():
                # l1.weight.stop_gradient is False
                tmp = l1.weight * 2  # tmp.stop_gradient is True
            x = fluid.dygraph.to_variable(data)
            y = l0(x) + tmp
            o = l1(y)
            o.backward()
            print(tmp.gradient() is None)  # True
            print(l0.weight.gradient() is None)  # False

        # use as decorator

        @fluid.dygraph.no_grad
        def test_layer():
            with fluid.dygraph.guard():
                inp = np.ones([3, 1024], dtype='float32')
                t = fluid.dygraph.base.to_variable(inp)
                linear1 = fluid.Linear(1024, 4, bias_attr=False)
                linear2 = fluid.Linear(4, 4)
                ret = linear1(t)
                dy_ret = linear2(ret)

        test_layer()

    """
    if func is None:
        return _switch_tracer_mode_guard_(is_train=False)

    if inspect.isgeneratorfunction(func):
        # Calling a generator function only creates the generator; keep the
        # guard active while it is iterated (same approach as ``no_grad_``).
        @decorator.decorator
        def __impl__(func, *args, **kwargs):
            gen = func(*args, **kwargs)
            with _switch_tracer_mode_guard_(is_train=False):
                for x in gen:
                    yield x

        return __impl__(func)

    @decorator.decorator
    def __impl__(func, *args, **kwargs):
        with _switch_tracer_mode_guard_(is_train=False):
            return func(*args, **kwargs)

    return __impl__(func)


class no_grad_:
    """
    :api_attr: imperative

    Create a context which disables dygraph gradient calculation.
    In this mode, the result of every computation will have `stop_gradient` set
    to `True`.

    Also functions as a decorator. (Make sure to use an instance.)
    The same instance may be entered again while it is already active (for
    example a recursive function decorated with ``@paddle.no_grad()``); each
    exit restores the state saved by its matching enter.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle

        # use as context manager

        data = np.array([[2, 3], [4, 5]]).astype('float32')
        l0 = paddle.nn.Linear(2, 2)  # l0.weight.gradient() is None
        l1 = paddle.nn.Linear(2, 2)
        with paddle.no_grad():
            # l1.weight.stop_gradient is False
            tmp = l1.weight * 2  # tmp.stop_gradient is True
        x = paddle.to_tensor(data)
        y = l0(x) + tmp
        o = l1(y)
        o.backward()
        print(tmp.gradient() is None)  # True
        print(l0.weight.gradient() is None)  # False

        # use as decorator

        @paddle.no_grad()
        def test_layer():
            inp = np.ones([3, 1024], dtype='float32')
            t = paddle.to_tensor(inp)
            linear1 = paddle.nn.Linear(1024, 4, bias_attr=False)
            linear2 = paddle.nn.Linear(4, 4)
            ret = linear1(t)
            dy_ret = linear2(ret)

        test_layer()
    """

    def __call__(self, func):
        @decorator.decorator
        def _decorate_function(func, *args, **kwargs):
            with self:
                return func(*args, **kwargs)

        @decorator.decorator
        def _decorate_generator(func, *args, **kwargs):
            # Keep gradients disabled while the generator body is iterated.
            gen = func(*args, **kwargs)
            with self:
                for x in gen:
                    yield x

        if inspect.isgeneratorfunction(func):
            return _decorate_generator(func)
        else:
            return _decorate_function(func)

    def _saved_has_grad(self):
        # Per-instance stack of saved ``_has_grad`` values, created lazily since
        # the class defines no ``__init__``. A stack (rather than one slot) lets
        # nested use of the same instance restore each level's own state.
        return self.__dict__.setdefault('_has_grad_stack', [])

    def __enter__(self):
        tracer = framework._dygraph_tracer()
        if tracer:
            self._saved_has_grad().append(tracer._has_grad)
            tracer._has_grad = False
        else:
            # Push a placeholder so pushes and pops stay balanced.
            self._saved_has_grad().append(None)

    def __exit__(self, *args):
        saved = self._saved_has_grad()
        orig = saved.pop() if saved else None
        tracer = framework._dygraph_tracer()
        # Only restore when __enter__ actually saved a value from a tracer.
        if tracer and orig is not None:
            tracer._has_grad = orig
340 341


S
rename  
sneaxiy 已提交
342
@signature_safe_contextmanager
def guard(place=None):
    """
    :api_attr: imperative

    This context will create a dygraph context for dygraph to run, using python ``with`` statement.

    Within the context, a new main/startup ``Program`` pair is installed via
    ``program_guard``, ``unique_name.guard()`` is active, a new ``Tracer`` is
    installed as the dygraph tracer, and the expected place is set to ``place``
    (or to the current expected place when ``place`` is None).

    Parameters:
        place(fluid.CPUPlace| fluid.CUDAPlace|str, optional): Place to execute dygraph.
            If None, the running place will be determined according to the way of paddle compilation.
            If ``place`` is string, It can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
            index of the GPUs or XPUs. Default: None

    Yields:
        None. Use it as a context manager in a ``with`` statement.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard():
            inp = np.ones([3, 1024], dtype='float32')
            t = fluid.dygraph.base.to_variable(inp)
            linear1 = fluid.Linear(1024, 4, bias_attr=False)
            linear2 = fluid.Linear(4, 4)
            ret = linear1(t)
            dy_ret = linear2(ret)

    """
    train = framework.Program()
    startup = framework.Program()
    tracer = Tracer()

    if place is not None:
        expected_place = _get_paddle_place(place)
    else:
        expected_place = framework._current_expected_place()

    with framework.program_guard(train, startup):
        with framework.unique_name.guard():
            with framework._dygraph_guard(tracer):
                with framework._dygraph_place_guard(expected_place):
                    yield
389 390


391 392 393 394
@framework.dygraph_only
def grad(outputs,
         inputs,
         grad_outputs=None,
         retain_graph=None,
         create_graph=False,
         only_inputs=True,
         allow_unused=False,
         no_grad_vars=None):
    '''
    .. note::
        **This API is ONLY available in Dygraph mode.**

    This API computes the sum of gradients of `outputs` with respect to each `inputs` .

    Parameters:
        outputs (Tensor|list(Tensor)|tuple(Tensor)): the output Tensor or
            Tensor list/tuple of the graph to compute gradients.
        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
            Tensor list/tuple of the graph to compute gradients. The returned
            values of this API are the gradients of `inputs` .
        grad_outputs (Tensor|list(Tensor|None)|tuple(Tensor|None), optional):
            initial gradient values of `outputs` . If `grad_outputs` is None,
            the initial gradient values of `outputs` would be Tensors filled with 1;
            if `grad_outputs` is not None, it must have the same length as `outputs` ,
            and in this case, the initial gradient value of the i-th `outputs` would
            be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs`
            is None; (2) the i-th element of `grad_outputs` when the i-th element of
            `grad_outputs` is a Tensor. Default None.
        retain_graph (bool, optional): whether to retain the forward graph which
            is used to calculate the gradient. When it is True, the graph would
            be retained, in which way users can calculate backward twice for the
            same graph. When it is False, the graph would be freed. Default None,
            which means it is equal to `create_graph` .
        create_graph (bool, optional): whether to create the gradient graphs of
            the computing process. When it is True, higher order derivatives are
            supported to compute; when it is False, the gradient graphs of the
            computing process would be discarded. Default False.
        only_inputs (bool, optional): whether to only compute the gradients of
            `inputs` . If it is False, the gradients of all remaining leaf
            Tensors in the graph would be also computed and accumulated.
            If it is True, only the gradients of `inputs` would be computed.
            Default True. only_inputs=False is under development, and it is
            not supported yet.
        allow_unused (bool, optional): whether to raise error or return None if some
            Tensors of `inputs` are unreachable in the graph. If some Tensors of
            `inputs` are unreachable in the graph (i.e., their gradients are None),
            error would be raised if allow_unused=False, or None would be returned as
            their gradients if allow_unused=True. Default False.
        no_grad_vars (Tensor|list(Tensor)|tuple(Tensor)|set(Tensor), optional):
            the Tensors whose gradients are not needed to compute. Default None.

    Returns:
        tuple: a tuple of Tensors, whose length is the same as the Tensor number
        inside `inputs`, and the i-th returned Tensor is the sum of gradients of
        `outputs` with respect to the i-th `inputs`.

    Raises:
        AssertionError: if any argument fails validation. The checks are explicit
            ``raise`` statements, so they are also performed under ``python -O``.

    Examples 1:
        .. code-block:: python

            import paddle

            def test_dygraph_grad(create_graph):
                x = paddle.ones(shape=[1], dtype='float32')
                x.stop_gradient = False
                y = x * x

                # Since y = x * x, dx = 2 * x
                dx = paddle.grad(
                        outputs=[y],
                        inputs=[x],
                        create_graph=create_graph,
                        retain_graph=True)[0]

                z = y + dx

                # If create_graph = False, the gradient of dx
                # would not be backpropagated. Therefore,
                # z = x * x + dx, and x.gradient() = 2 * x = 2.0

                # If create_graph = True, the gradient of dx
                # would be backpropagated. Therefore,
                # z = x * x + dx = x * x + 2 * x, and
                # x.gradient() = 2 * x + 2 = 4.0

                z.backward()
                return x.gradient()

            print(test_dygraph_grad(create_graph=False)) # [2.]
            print(test_dygraph_grad(create_graph=True)) # [4.]

    Examples 2:
        .. code-block:: python

            import paddle

            def test_dygraph_grad(grad_outputs=None):
                x = paddle.to_tensor(2.0)
                x.stop_gradient = False

                y1 = x * x
                y2 = x * 3

                # If grad_outputs=None, dy1 = [1], dy2 = [1].
                # If grad_outputs=[g1, g2], then:
                #    - dy1 = [1] if g1 is None else g1
                #    - dy2 = [1] if g2 is None else g2

                # Since y1 = x * x, dx = 2 * x * dy1.
                # Since y2 = x * 3, dx = 3 * dy2.
                # Therefore, the final result would be:
                # dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2.

                dx = paddle.grad(
                    outputs=[y1, y2],
                    inputs=[x],
                    grad_outputs=grad_outputs)[0]

                return dx.numpy()

            grad_value = paddle.to_tensor(4.0)
            # dy1 = [1], dy2 = [1]
            print(test_dygraph_grad(None)) # [7.]

            # dy1 = [1], dy2 = [4]
            print(test_dygraph_grad([None, grad_value])) # [16.]

            # dy1 = [4], dy2 = [1]
            print(test_dygraph_grad([grad_value, None])) # [19.]

            # dy1 = [3], dy2 = [4]
            grad_y1 = paddle.to_tensor(3.0)
            print(test_dygraph_grad([grad_y1, grad_value])) # [24.]
    '''

    def check_in_out(in_out_list, name):
        # Validate `outputs`/`inputs` and normalize a single Tensor into a list.
        if in_out_list is None:
            raise AssertionError("{} should not be None".format(name))

        if isinstance(in_out_list, (list, tuple)):
            if len(in_out_list) == 0:
                raise AssertionError("{} cannot be empty".format(name))
            for each_var in in_out_list:
                if not isinstance(each_var, core.VarBase):
                    raise AssertionError(
                        "Elements of {} must be Variable".format(name))
            return in_out_list

        if not isinstance(in_out_list, core.VarBase):
            raise AssertionError(
                "{} must be Variable or list of Variable".format(name))
        return [in_out_list]

    outputs = check_in_out(outputs, 'outputs')
    inputs = check_in_out(inputs, 'inputs')

    if grad_outputs is None:
        grad_outputs = []
    else:
        if not isinstance(grad_outputs, (list, tuple)):
            grad_outputs = [grad_outputs]

        for each_var in grad_outputs:
            if each_var is not None and not isinstance(each_var,
                                                       core.VarBase):
                raise AssertionError(
                    "grad_outputs must be None, a Variable or a list containing None or Variables"
                )

    # An empty grad_outputs means "use the default all-ones initial gradients".
    if len(grad_outputs) > 0 and len(grad_outputs) != len(outputs):
        raise AssertionError(
            "The length of grad_outputs must be equal to outputs")

    if no_grad_vars is None:
        no_grad_vars = []
    elif isinstance(no_grad_vars, core.VarBase):
        no_grad_vars = [no_grad_vars]
    elif isinstance(no_grad_vars, (list, tuple, set)):
        no_grad_vars = list(no_grad_vars)
        for var in no_grad_vars:
            if not isinstance(var, core.VarBase):
                raise AssertionError(
                    "no_grad_vars can only contains Variable")
    else:
        raise AssertionError(
            "no_grad_vars must be None, Variable or list/tuple/set of Variables")

    if not isinstance(create_graph, bool):
        raise AssertionError("create_graph must be True or False")

    if retain_graph is None:
        retain_graph = create_graph

    if not isinstance(retain_graph, bool):
        raise AssertionError("retain_graph must be None, True or False")

    if not isinstance(allow_unused, bool):
        raise AssertionError("allow_unused must be True or False")

    if not isinstance(only_inputs, bool):
        raise AssertionError("only_inputs must be True or False")
    if not only_inputs:
        raise AssertionError("only_inputs=False is not supported yet")

    place = core.Place()
    place.set_place(framework._current_expected_place())
    return core.dygraph_partial_grad(inputs, outputs, grad_outputs,
                                     no_grad_vars, place, create_graph,
                                     retain_graph, allow_unused, only_inputs)
594 595


596
@framework.dygraph_only
def to_variable(value, name=None, zero_copy=None, dtype=None):
    r"""
    :api_attr: imperative

    The API will create a ``Variable`` object from
    tuple, list, numpy\.ndarray or Variable object.

    Parameters:
        value(tuple|list|ndarray|Variable|Tensor): Initial data.
            Can be a list, tuple, NumPy ndarray, Variable, Tensor.
            The shape can be multi-dimensional. The data type is one of
            numpy\.{float16, float32, float64, int16, int32, int64,
            uint8, uint16, complex64, complex128}.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name` .
        zero_copy(bool, optional): Whether to share memory with the input numpy
            array. Zero-copy is currently disabled: on CPUPlace a truthy value
            is ignored with a warning and the data is copied; on other places a
            truthy value raises ``AssertionError``. Default: None (copy).
        dtype(str, optional): The desired data type of returned ``Variable`` .
            Can be 'bool' , 'float16' , 'float32' , 'float64' , 'int8' , 'int16' ,
            'int32' , 'int64' , 'uint8' . Default: None.

    Returns:
        Variable : If ``value`` is a Variable, it is returned unchanged; if it
            is a Tensor/LoDTensor, it is wrapped into a ``Variable``; if it is a
            tuple/list/numpy\.ndarray object, return ``Tensor`` created from the
            corresponding numpy\.ndarray object (cast to ``dtype`` if given),
            which has the same shape as ``value``.

    Raises:
        TypeError: if ``value`` is not one of the supported types.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard(fluid.CPUPlace()):
            x = np.ones([2, 2], np.float32)
            y = fluid.dygraph.to_variable(x, zero_copy=False)
            x[0][0] = -1
            y[0][0].numpy()  # array([1.], dtype=float32)
            y = fluid.dygraph.to_variable(x)
            x[0][0] = 0
            # memory is not shared (zero_copy is disabled), so y keeps -1
            y[0][0].numpy()  # array([-1.], dtype=float32)
            c = np.array([2+1j, 2])
            z = fluid.dygraph.to_variable(c)
            z.numpy() # array([2.+1.j, 2.+0.j])
            z.dtype # 'complex128'

            y = fluid.dygraph.to_variable([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
            y.shape     # [3, 2]

            y = fluid.dygraph.to_variable(((0.1, 1.2), (2.2, 3.1), (4.9, 5.2)), dtype='int32')
            y.shape     # [3, 2]

    """
    support_type = (list, tuple, np.ndarray, core.VarBase, framework.Variable,
                    core.Tensor, core.LoDTensor)
    if not isinstance(value, support_type):
        raise TypeError(
            "The type of 'value' in fluid.dygraph.to_variable must be %s, but received %s."
            % (support_type, type(value)))
    if isinstance(value, (core.VarBase, framework.Variable)):
        return value
    if isinstance(value, (core.Tensor, core.LoDTensor)):
        return core.VarBase(value)

    place = framework._current_expected_place()
    if isinstance(place, framework.core.CPUPlace):
        #TODO(zhiqiu): we found two problems when enable zero_copy on CPUPlace.
        # (1): eigen requires 16-bytes alignments, but the data of numpy array may not statisfy.
        # Details: https://eigen.tuxfamily.org/dox/group__TopicUnalignedArrayAssert.html
        # (2): when used in flask framework, it may result in hang.
        # Details: https://github.com/PaddlePaddle/Paddle/issues/26635
        # So, we temporally diable the zero_copy strategy.
        if zero_copy:
            warnings.warn(
                "Currently, zero_copy is not supported, and it will be discarded."
            )
            zero_copy = False
    elif zero_copy:
        # Explicit raise (same exception type as before) so the check also
        # runs under `python -O`.
        raise AssertionError("zero_copy mode can only be used with CPUPlace")

    if not isinstance(value, np.ndarray):
        value = np.array(value)

    if dtype is not None:
        dtype = convert_dtype(dtype)
        if value.dtype != dtype:
            value = value.astype(dtype)

    return core.VarBase(
        value=value,
        place=place,
        persistable=False,
        zero_copy=zero_copy,
        name=name if name else '')