# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import decorator
import contextlib
import functools
import inspect
import sys
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar
from .tracer import Tracer
import logging
from ..data_feeder import convert_dtype
import warnings
from ..framework import _get_paddle_place
import paddle

__all__ = [
    'no_grad', 'no_grad_', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph',
    'enabled', 'to_variable'
]


def _switch_to_static_graph_(func):
    def __impl__(*args, **kwargs):
        with framework._dygraph_guard(None):
            return func(*args, **kwargs)

    return __impl__


switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
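# Illustrative sketch (not from this file): decorating a function with
# `switch_to_static_graph` runs its body with the dygraph tracer switched
# off, so code inside sees static graph mode. `_build_program` is a
# hypothetical name:
#
#     @switch_to_static_graph
#     def _build_program():
#         assert not framework.in_dygraph_mode()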


@signature_safe_contextmanager
def program_desc_tracing_guard(enable):
    tracer = framework._dygraph_tracer()
    if tracer:
        original_val = tracer._enable_program_desc_tracing
        tracer._enable_program_desc_tracing = enable
    try:
        yield
    finally:
        if tracer:
            tracer._enable_program_desc_tracing = original_val
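# A minimal usage sketch for the guard above (hypothetical caller): enable
# ProgramDesc recording while tracing a layer, e.g. when capturing a model:
#
#     with program_desc_tracing_guard(True):
#         outputs = layer(*inputs)  # ops executed here are also recorded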


_functional_dygraph_context_manager = None


@signature_safe_contextmanager
def param_guard(parameters):
    from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
    # Note: parameters is a reference of self._parameters or self._buffers
    if in_declarative_mode() and not framework.in_dygraph_mode() and parameters:
        origin_parameters = parameters.copy()
        for name, var_base in parameters.items():
            if isinstance(var_base, list):
                new_var = [_convert_into_variable(var) for var in var_base]
            else:
                new_var = _convert_into_variable(var_base)
            parameters[name] = new_var
        yield
        parameters.update(origin_parameters)
    else:
        yield
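# A hedged sketch of the intended call site (assumed, based on the note
# above that `parameters` aliases self._parameters or self._buffers):
#
#     with param_guard(self._parameters), param_guard(self._buffers):
#         outputs = self.forward(*inputs)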


def _convert_into_variable(var_base):
    """
    Convert VarBase into Variable.
    """
    if isinstance(var_base, core.VarBase):
        # Check whether has been created before.
        new_var = var_base.block._find_var_recursive(var_base.name)
        if new_var is not None:
            assert isinstance(new_var, framework.Variable)
        # Convert ParamBase into Parameter with same attributes in dy2stat.
        elif isinstance(var_base, framework.ParamBase):
            new_var = var_base._to_static_var(to_parameter=True)
        else:
            # Note(Aurelius84): Convert VarBase in self._buffers into Variable with
            # same attributes and set persistable=True to allow saving this var.
            # Because users can create a VarBase in `__init__`, like a `mask`
            # Tensor or `hidden_0` in RNN layers, which acts like a Parameter
            # and is needed for inference. It will be pruned if it is not
            # actually required for inference.

            # But if its shape is empty while created from `create_variable()`, we consider this buffer
            # non-persistable. See case of `drop_state` in lstm api.
            is_persistable = len(var_base.shape) > 0

            new_var = var_base._to_static_var(
                to_parameter=False, persistable=is_persistable)
        return new_var
    else:
        return var_base


def enabled():
    """
    This function checks whether the program runs in dynamic graph mode or not.
    You can enter dynamic graph mode with the :ref:`api_fluid_dygraph_guard` API,
    or enable and disable dynamic graph mode with the :ref:`api_fluid_dygraph_enable_dygraph`
    and :ref:`api_fluid_dygraph_disable_dygraph` APIs.

    **Note**:
        ``fluid.dygraph.enabled`` is an alias of ``fluid.in_dygraph_mode``, and
        using ``fluid.in_dygraph_mode`` is recommended.

    Returns:
        bool: Whether the program is running in dynamic graph mode.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            fluid.enable_dygraph()  # Now we are in dygraph mode
            print(fluid.dygraph.enabled())  # True
            fluid.disable_dygraph()
            print(fluid.dygraph.enabled())  # False
    """
    return framework.in_dygraph_mode()


def enable_dygraph(place=None):
    """

    .. note::
        Dynamic graph mode is turned ON by default since paddle 2.0.0

    This API turns OFF static graph mode. You can turn ON static graph mode with `enable_static <./disable_dygraph_en.html>`_ .

    Parameters:
        place(paddle.CPUPlace|paddle.CUDAPlace|str, optional): Place to run dynamic graph. Default: None,
            which means the running place will be determined according to the way paddle was compiled.
            If ``place`` is a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the
            index of the GPU.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            print(paddle.in_dynamic_mode())  # True, dynamic mode is turned ON by default since paddle 2.0.0

            paddle.enable_static()
            print(paddle.in_dynamic_mode())  # False, now we are in static mode

            paddle.disable_static()
            print(paddle.in_dynamic_mode())  # True, now we are in dynamic mode

    """
    global _functional_dygraph_context_manager
    if _functional_dygraph_context_manager is None:
        _functional_dygraph_context_manager = guard(
            place=_get_paddle_place(place))
        _functional_dygraph_context_manager.__enter__()

        # call disable_dygraph when Python exits
        CleanupFuncRegistrar.register(disable_dygraph)


def disable_dygraph():
    """

    .. note::
        Dynamic graph mode is turned ON by default since paddle 2.0.0

    This API turns ON static graph mode. You can turn it OFF again with `disable_static <./enable_dygraph_en.html>`_ .

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            print(paddle.in_dynamic_mode())  # True, dynamic mode is turned ON by default since paddle 2.0.0

            paddle.enable_static()
            print(paddle.in_dynamic_mode())  # False, now we are in static mode

            paddle.disable_static()
            print(paddle.in_dynamic_mode())  # True, now we are in dynamic mode

    """
    global _functional_dygraph_context_manager
    if _functional_dygraph_context_manager is not None:
        _functional_dygraph_context_manager.__exit__(*sys.exc_info())
        _functional_dygraph_context_manager = None


@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    tracer = framework._dygraph_tracer()
    if tracer:
        has_grad = tracer._has_grad
        tracer._has_grad = is_train
        try:
            yield
        finally:
            tracer._has_grad = has_grad
    else:
        yield
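# This guard backs `no_grad` below: with is_train=False, the tracer stops
# recording gradients, so results computed inside get stop_gradient=True.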


def no_grad(func=None):
    """
    :api_attr: imperative

    Create a context which disables dygraph gradient calculation.
    In this mode, the result of every computation will have `stop_gradient=True`.

    Also functions as a decorator. (Make sure to use it without parentheses.)

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        # use as context manager

        data = np.array([[2, 3], [4, 5]]).astype('float32')
        with fluid.dygraph.guard():
            l0 = fluid.Linear(2, 2)  # l0.weight.gradient() is None
            l1 = fluid.Linear(2, 2)
            with fluid.dygraph.no_grad():
                # l1.weight.stop_gradient is False
                tmp = l1.weight * 2  # tmp.stop_gradient is True
            x = fluid.dygraph.to_variable(data)
            y = l0(x) + tmp
            o = l1(y)
            o.backward()
            print(tmp.gradient() is None)  # True
            print(l0.weight.gradient() is None)  # False

        # use as decorator

        @fluid.dygraph.no_grad
        def test_layer():
            with fluid.dygraph.guard():
                inp = np.ones([3, 1024], dtype='float32')
                t = fluid.dygraph.base.to_variable(inp)
                linear1 = fluid.Linear(1024, 4, bias_attr=False)
                linear2 = fluid.Linear(4, 4)
                ret = linear1(t)
                dy_ret = linear2(ret)

        test_layer()

    """
    if func is None:
        return _switch_tracer_mode_guard_(is_train=False)
    else:

        @decorator.decorator
        def __impl__(func, *args, **kwargs):
            with _switch_tracer_mode_guard_(is_train=False):
                return func(*args, **kwargs)

        return __impl__(func)


class no_grad_:
    """
    :api_attr: imperative

    Create a context which disables dygraph gradient calculation.
    In this mode, the result of every computation will have `stop_gradient` set
    to `True`.

    Also functions as a decorator. (Make sure to use an instance.)

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle

        # use as context manager

        data = np.array([[2, 3], [4, 5]]).astype('float32')
        l0 = paddle.nn.Linear(2, 2)  # l0.weight.gradient() is None
        l1 = paddle.nn.Linear(2, 2)
        with paddle.no_grad():
            # l1.weight.stop_gradient is False
            tmp = l1.weight * 2  # tmp.stop_gradient is True
        x = paddle.to_tensor(data)
        y = l0(x) + tmp
        o = l1(y)
        o.backward()
        print(tmp.gradient() is None)  # True
        print(l0.weight.gradient() is None)  # False

        # use as decorator

        @paddle.no_grad()
        def test_layer():
            inp = np.ones([3, 1024], dtype='float32')
            t = paddle.to_tensor(inp)
            linear1 = paddle.nn.Linear(1024, 4, bias_attr=False)
            linear2 = paddle.nn.Linear(4, 4)
            ret = linear1(t)
            dy_ret = linear2(ret)

        test_layer()
    """

    def __call__(self, func):
        @decorator.decorator
        def _decorate_function(func, *args, **kwargs):
            with self:
                return func(*args, **kwargs)

        @decorator.decorator
        def _decorate_generator(func, *args, **kwargs):
            gen = func(*args, **kwargs)
            with self:
                for x in gen:
                    yield x

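        # A generator function's body only runs while the generator is
        # iterated, so `with self` must wrap the iteration itself to keep
        # gradient tracking disabled across every yield.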
        if inspect.isgeneratorfunction(func):
            return _decorate_generator(func)
        else:
            return _decorate_function(func)

    def __enter__(self):
        tracer = framework._dygraph_tracer()
        if tracer:
            self.orig = tracer._has_grad
            tracer._has_grad = False

    def __exit__(self, *args):
        tracer = framework._dygraph_tracer()
        if tracer:
            tracer._has_grad = self.orig


@signature_safe_contextmanager
def guard(place=None):
    """
    :api_attr: imperative

    This context manager creates a dygraph context for dygraph to run in, using the Python ``with`` statement.

    Parameters:
        place(fluid.CPUPlace|fluid.CUDAPlace|str, optional): Place to execute dygraph.
            If None, the running place will be determined according to the way paddle was compiled.
            If ``place`` is a string, it can be ``cpu``, ``gpu:x`` or ``xpu:x``, where ``x`` is the
            index of the GPU or XPU. Default: None

    Returns:
        None

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard():
            inp = np.ones([3, 1024], dtype='float32')
            t = fluid.dygraph.base.to_variable(inp)
            linear1 = fluid.Linear(1024, 4, bias_attr=False)
            linear2 = fluid.Linear(4, 4)
            ret = linear1(t)
            dy_ret = linear2(ret)

    """
    train = framework.Program()
    startup = framework.Program()
    tracer = Tracer()
    VarBase = core.VarBase

    if place is not None:
        expected_place = _get_paddle_place(place)
    else:
        expected_place = framework._current_expected_place()

    with framework.program_guard(train, startup):
        with framework.unique_name.guard():
            with framework._dygraph_guard(tracer):
                with framework._dygraph_place_guard(expected_place):
                    yield


@framework.dygraph_only
def grad(outputs,
         inputs,
         grad_outputs=None,
         retain_graph=None,
         create_graph=False,
         only_inputs=True,
         allow_unused=False,
         no_grad_vars=None):
    ''' 
    .. note::
        **This API is ONLY available in imperative mode.**

    This API computes the sum of gradients of `outputs` with respect to each `inputs` .

    Parameters:
        outputs (Tensor|list(Tensor)|tuple(Tensor)): the output Tensor or 
            Tensor list/tuple of the graph to compute gradients.
        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or 
            Tensor list/tuple of the graph to compute gradients. The returned
            values of this API are the gradients of `inputs` . 
        grad_outputs (Tensor|list(Tensor|None)|tuple(Tensor|None), optional): 
            initial gradient values of `outputs` . If `grad_outputs` is None, 
            the initial gradient values of `outputs` would be Tensors filled with 1; 
            if `grad_outputs` is not None, it must have the same length as `outputs` , 
            and in this case, the initial gradient value of the i-th `outputs` would
            be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs` 
            is None; (2) the i-th element of `grad_outputs` when the i-th element of
            `grad_outputs` is a Tensor. Default None.
        retain_graph (bool, optional): whether to retain the forward graph which
            is used to calculate the gradient. When it is True, the graph would
            be retained, so that users can calculate backward twice for the
            same graph. When it is False, the graph would be freed. Default None,
            which means it is equal to `create_graph` .
        create_graph (bool, optional): whether to create the gradient graphs of
            the computing process. When it is True, higher order derivatives
            can be computed; when it is False, the gradient graphs of the
            computing process would be discarded. Default False.
        only_inputs (bool, optional): whether to only compute the gradients of
            `inputs` . If it is False, the gradients of all remaining leaf 
            Tensors in the graph would be also computed and accumulated. 
            If it is True, only the gradients of `inputs` would be computed.
            Default True. only_inputs=False is under development, and it is
            not supported yet.    
        allow_unused (bool, optional): whether to raise error or return None if some 
            Tensors of `inputs` are unreachable in the graph. If some Tensors of 
            `inputs` are unreachable in the graph (i.e., their gradients are None),  
            error would be raised if allow_unused=False, or None would be returned as
            their gradients if allow_unused=True. Default False.
        no_grad_vars (Tensor|list(Tensor)|tuple(Tensor)|set(Tensor), optional):
            the Tensors whose gradients do not need to be computed. Default None.

    Returns:
        tuple: a tuple of Tensors, whose length is the same as the Tensor number 
        inside `inputs`, and the i-th returned Tensor is the sum of gradients of 
        `outputs` with respect to the i-th `inputs`.

    Examples 1:
        .. code-block:: python

            import paddle

            def test_dygraph_grad(create_graph):
                x = paddle.ones(shape=[1], dtype='float32')
                x.stop_gradient = False
                y = x * x

                # Since y = x * x, dx = 2 * x
                dx = paddle.grad(
                        outputs=[y],
                        inputs=[x],
                        create_graph=create_graph,
                        retain_graph=True)[0]

                z = y + dx

                # If create_graph = False, the gradient of dx
                # would not be backpropagated. Therefore,
                # z = x * x + dx, and x.gradient() = 2 * x = 2.0

                # If create_graph = True, the gradient of dx
                # would be backpropagated. Therefore,
                # z = x * x + dx = x * x + 2 * x, and
                # x.gradient() = 2 * x + 2 = 4.0

                z.backward()
                return x.gradient()

            print(test_dygraph_grad(create_graph=False)) # [2.]
            print(test_dygraph_grad(create_graph=True)) # [4.]

    Examples 2:
        .. code-block:: python

            import paddle

            def test_dygraph_grad(grad_outputs=None):
                x = paddle.to_tensor(2.0)
                x.stop_gradient = False

                y1 = x * x
                y2 = x * 3 

                # If grad_outputs=None, dy1 = [1], dy2 = [1].
                # If grad_outputs=[g1, g2], then:
                #    - dy1 = [1] if g1 is None else g1
                #    - dy2 = [1] if g2 is None else g2

                # Since y1 = x * x, dx = 2 * x * dy1.
                # Since y2 = x * 3, dx = 3 * dy2.
                # Therefore, the final result would be:
                # dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2.

                dx = paddle.grad(
                    outputs=[y1, y2], 
                    inputs=[x],
                    grad_outputs=grad_outputs)[0]

                return dx.numpy()

            grad_value = paddle.to_tensor(4.0)
            # dy1 = [1], dy2 = [1]
            print(test_dygraph_grad(None)) # [7.]

            # dy1 = [1], dy2 = [4]
            print(test_dygraph_grad([None, grad_value])) # [16.]

            # dy1 = [4], dy2 = [1]
            print(test_dygraph_grad([grad_value, None])) # [19.]

            # dy1 = [3], dy2 = [4]
            grad_y1 = paddle.to_tensor(3.0)
            print(test_dygraph_grad([grad_y1, grad_value])) # [24.]
    '''

    def check_in_out(in_out_list, name):
        assert in_out_list is not None, "{} should not be None".format(name)

        if isinstance(in_out_list, (list, tuple)):
            assert len(in_out_list) > 0, "{} cannot be empty".format(name)
            for each_var in in_out_list:
                assert isinstance(
                    each_var,
                    core.VarBase), "Elements of {} must be Variable".format(
                        name)
            return in_out_list
        else:
            assert isinstance(
                in_out_list,
                core.VarBase), "{} must be Variable or list of Variable".format(
                    name)
            return [in_out_list]

    outputs = check_in_out(outputs, 'outputs')
    inputs = check_in_out(inputs, 'inputs')

    if grad_outputs is not None:
        if not isinstance(grad_outputs, (list, tuple)):
            grad_outputs = [grad_outputs]

        for each_var in grad_outputs:
            if each_var is not None:
                assert isinstance(
                    each_var, core.VarBase
                ), "grad_outputs must be None, a Variable or a list containing None or Variables"
    else:
        grad_outputs = []

    if len(grad_outputs) > 0:
        assert len(grad_outputs) == len(
            outputs), "The length of grad_outputs must be equal to outputs"

    if no_grad_vars is None:
        no_grad_vars = []
    elif isinstance(no_grad_vars, core.VarBase):
        no_grad_vars = [no_grad_vars]
    elif isinstance(no_grad_vars, (list, tuple, set)):
        no_grad_vars = list(no_grad_vars)
        for var in no_grad_vars:
            assert isinstance(
                var, core.VarBase), "no_grad_vars can only contain Variable"
    else:
        raise AssertionError(
            "no_grad_vars must be None, Variable or list/tuple/set of Variables")

    assert isinstance(create_graph, bool), "create_graph must be True or False"

    if retain_graph is None:
        retain_graph = create_graph

    assert isinstance(retain_graph,
                      bool), "retain_graph must be None, True or False"

    assert isinstance(allow_unused, bool), "allow_unused must be True or False"

    assert isinstance(only_inputs, bool), "only_inputs must be True or False"
    assert only_inputs, "only_inputs=False is not supported yet"

    place = core.Place()
    place.set_place(framework._current_expected_place())
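    # The partial-grad computation itself runs in C++ (core.dygraph_partial_grad);
    # the Python code above only normalizes and validates the arguments.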
    return core.dygraph_partial_grad(inputs, outputs, grad_outputs,
                                     no_grad_vars, place, create_graph,
                                     retain_graph, allow_unused, only_inputs)


@framework.dygraph_only
def to_variable(value, name=None, zero_copy=None, dtype=None):
    r"""
    :api_attr: imperative

    This API creates a ``Variable`` object from a
    tuple, list, numpy\.ndarray or Variable object.

    Parameters:
        value(tuple|list|ndarray|Variable|Tensor): Initial data. 
            Can be a list, tuple, NumPy ndarray, Variable, Tensor.
            The shape can be multi-dimensional. The data type is one of 
            numpy\.{float16, float32, float64, int16, int32, int64, 
            uint8, uint16, complex64, complex128}.
        name(str, optional): The default value is None. Normally there is no 
            need for user to set this property. For more information, please 
            refer to :ref:`api_guide_Name` . 
        zero_copy(bool, optional): Whether to share memory with the input numpy 
            array. This parameter only works with CPUPlace and will be set to 
            True when it is None. Default: None. (Note: zero_copy is temporarily disabled; see the notes in the implementation.)
        dtype(str, optional): The desired data type of returned ``Variable`` .
            Can be 'bool' , 'float16' , 'float32' , 'float64' , 'int8' , 'int16' , 
            'int32' , 'int64' , 'uint8' . Default: None.

    Returns:
        Variable : If ``value`` is a tuple/list/numpy\.ndarray object,
            return ``Tensor`` created from the corresponding numpy\.ndarray object, which has
            the same data type and shape as ``value``.


    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard(fluid.CPUPlace()):
            x = np.ones([2, 2], np.float32)
            y = fluid.dygraph.to_variable(x, zero_copy=False)
            x[0][0] = -1
            y[0][0].numpy()  # array([1.], dtype=float32)
            y = fluid.dygraph.to_variable(x)
            x[0][0] = 0
            y[0][0].numpy()  # array([0.], dtype=float32)
            c = np.array([2+1j, 2])
            z = fluid.dygraph.to_variable(c)
            z.numpy() # array([2.+1.j, 2.+0.j])
            z.dtype # 'complex128'

            y = fluid.dygraph.to_variable([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
            y.shape     # [3L, 2L]

            y = fluid.dygraph.to_variable(((0.1, 1.2), (2.2, 3.1), (4.9, 5.2)), dtype='int32')
            y.shape     # [3L, 2L]

    """
    support_type = (list, tuple, np.ndarray, core.VarBase, framework.Variable,
                    core.Tensor, core.LoDTensor)
    if not isinstance(value, support_type):
        raise TypeError(
            "The type of 'value' in fluid.dygraph.to_variable must be %s, but received %s."
            % (support_type, type(value)))
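    # Fast paths first: VarBase/Variable inputs are returned unchanged and
    # Tensor/LoDTensor inputs are wrapped into a VarBase; other supported
    # types go through the numpy conversion below.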
    if isinstance(value, (core.VarBase, framework.Variable)):
        return value
    elif isinstance(value, (core.Tensor, core.LoDTensor)):
        return core.VarBase(value)
    else:
679 680
        if isinstance(framework._current_expected_place(),
                      framework.core.CPUPlace):
L
682
            # (1): eigen requires 16-bytes alignments, but the data of numpy array may not statisfy.
L
            # (2): when used in flask framework, it may result in hang.
            # Details: https://github.com/PaddlePaddle/Paddle/issues/26635
            # So, we temporarily disable the zero_copy strategy.
            if zero_copy == True:
                warnings.warn(
                    "Currently, zero_copy is not supported, and it will be discarded."
                )
                zero_copy = False
692 693
        else:
            assert not zero_copy, "zero_copy mode can only be used with CPUPlace"
694 695 696 697 698 699 700 701 702

        if not isinstance(value, np.ndarray):
            value = np.array(value)

        if dtype is not None:
            dtype = convert_dtype(dtype)
            if value.dtype != dtype:
                value = value.astype(dtype)

        py_var = core.VarBase(
            value=value,
            place=framework._current_expected_place(),
            persistable=False,
            zero_copy=zero_copy,
            name=name if name else '')
        return py_var